
feat(moe): add moe support and fused topk & moe kernels#37

Open
MikanAffine wants to merge 6 commits into SJTU-DENG-Lab:main from MikanAffine:fusedmoe

Conversation


@MikanAffine MikanAffine commented Mar 31, 2026

Features:

  • add MoE support to SDAR-MOE
  • add fused TopK and fused MoE Triton kernels, with unit tests

Bugfixes:

  • fix prev_block being None during inference
  • fix weakrefs not being pickle-able when TP > 1

Summary by CodeRabbit

  • New Features

    • Added fused Mixture-of-Experts (MoE) computation with tensor-parallel support and optimized routing.
    • Introduced new fused top-K routing selection with Triton kernel acceleration.
    • Added new sdar_moe sampler variant for flexible model selection.
  • Bug Fixes

    • Improved robustness when handling missing predecessor blocks in engine logic.
    • Enhanced serialization support for better model checkpointing.
  • Tests

    • Added comprehensive test suites for fused MoE and top-K kernels with multi-dtype and edge-case coverage.

- the latest version of transformers will set a default value for rope_scaling when it is None
- rope_scaling is not currently implemented in the engine
- stack weights in order to execute MoE GEMMs together
- example: gate_proj + up_proj -> w13
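As a sketch of the stacking idea above (the names gate_w/up_w and the shapes are assumed for illustration; w13 follows the commit's naming):

```python
import torch

H, inter = 8, 4                   # hidden and intermediate sizes (assumed)
gate_w = torch.randn(inter, H)    # gate_proj weight, shape (inter, H)
up_w = torch.randn(inter, H)      # up_proj weight, shape (inter, H)

# Stack along the output dim so both projections run as one GEMM
w13 = torch.cat([gate_w, up_w], dim=0)   # shape (2*inter, H)

x = torch.randn(3, H)                    # 3 tokens
fused = x @ w13.t()                      # single matmul instead of two
gate_out, up_out = fused.split(inter, dim=-1)

# Equivalent to running the two projections separately
assert torch.allclose(gate_out, x @ gate_w.t(), atol=1e-5)
assert torch.allclose(up_out, x @ up_w.t(), atol=1e-5)
```

Per-expert GEMMs over the stacked weight then produce both the gate and up activations in one pass, which is what the fused kernel exploits.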

coderabbitai bot commented Mar 31, 2026

📝 Walkthrough

Added pickle support for block serialization with weakref handling, extended MoE routing with Triton-accelerated kernels for fused top-k and expert computation, introduced expert parameter loading infrastructure, and updated control flow for null-safe predecessor block handling across multiple components.

Changes

Cohort / File(s) Summary
Block Serialization & Null Safety
diffulex/engine/dllm_block.py, diffulex/mixin/multi_block/engine/request.py
Added __getstate__/__setstate__ methods to DllmBlock and DllmBlockBuffer for pickle support with weakref dereferencing. Updated should_force_decode_topk and predecessor checks to gracefully handle None prev_block with conditional logic and guards.
Routing & Top-K Selection
diffulex/moe/topk.py, diffulex_kernel/python/fused_topk_triton.py
Extended TopKRouter with configurable backend selection (impl="torch" or impl="triton"). Added topk_pytorch_reference for PyTorch-based top-k routing and new Triton kernel fused_topk implementing numerically stable softmax/sigmoid with greedy top-k selection, returning (topk_weights, topk_ids, token_expert_mapping). Updated TopKOutput dataclass fields.
Fused MoE Kernels
diffulex_kernel/python/fused_moe_triton.py
Implemented Triton-based fused MoE computation with two-pass kernels: pass 1 loads routed activations and expert gate/up projections, applies SiLU activation; pass 2 computes final projections via w2 weights and atomically accumulates outputs. Includes sorting helper for padded expert tile management.
MoE Block Implementation
diffulex/moe/fused_moe.py, diffulex/moe/moe_impl.py, diffulex/moe/__init__.py
Added FusedSparseMoEBlock using ReplicatedLinear router and fused kernels with TP support. Updated SparseMoEBlock to specify TopKRouter backend as impl="torch". Switched build_mlp_or_moe to instantiate FusedSparseMoEBlock for MoE layers.
Expert Parameter Loading
diffulex/utils/loader.py, diffulex/model/sdar_moe.py
Added stacked_params_mapping class attribute to SDARMoEForDiffusionLM defining expert weight indices. Implemented _parse_expert_id helper and extended checkpoint loading to support stacked_params_mapping-guided expert weight distribution via custom weight_loader hooks.
Kernel Module & Exports
diffulex_kernel/__init__.py
Added lazy loading for fused_topk and fused_moe kernel functions via __getattr__ dispatch; added fused_topk to __all__.
Sampler Configuration
diffulex/sampler/sdar.py
Registered SDARSampler under additional key sdar_moe in AutoSampler.
Test Coverage
test/python/kernel/test_fused_moe.py, test/python/kernel/test_fused_topk.py
Added comprehensive test suites validating fused kernels against PyTorch reference implementations across multiple dimensions, dtypes, edge cases, and parametrized configurations; includes determinism, renormalization, and numerical stability checks.
Configuration
pyproject.toml
Pinned transformers dependency to exact version 4.53.2.
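To make the two-pass kernel's semantics concrete, here is a hedged, loop-based PyTorch reference (an illustration written for this summary, not the repo's code; shapes follow the (E, 2*I, H) and (E, H, I) layout described above):

```python
import torch
import torch.nn.functional as F

def moe_reference(x, w13, w2, topk_weights, topk_ids):
    """Per-token loop over the chosen experts: SiLU(gate) * up, then w2."""
    M, H = x.shape
    inter = w13.shape[1] // 2
    out = torch.zeros_like(x)
    for t in range(M):
        for k in range(topk_ids.shape[1]):
            e = int(topk_ids[t, k])
            gate_up = w13[e] @ x[t]                        # (2*inter,)
            h = F.silu(gate_up[:inter]) * gate_up[inter:]  # gated activation
            out[t] += topk_weights[t, k] * (w2[e] @ h)     # back to (H,)
    return out

M, H, inter, E, top_k = 4, 8, 6, 3, 2
x = torch.randn(M, H)
w13, w2 = torch.randn(E, 2 * inter, H), torch.randn(E, H, inter)
weights = torch.full((M, top_k), 0.5)
ids = torch.randint(0, E, (M, top_k))
y = moe_reference(x, w13, w2, weights, ids)
```

The fused kernels compute the same result tile-by-tile: pass 1 produces the gated intermediate, pass 2 applies w2 and accumulates the weighted per-expert contributions into the output.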

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~75 minutes

Poem

Kernels dancing, fused and bright, 🐰✨
Experts routed with pure delight,
Pickle safety, blocks stand tall,
Parameter loading conquers all! 🎊

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning: docstring coverage is 60.40%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
  • Title check ✅ Passed: the title accurately describes the main additions to the PR (MoE support with fused TopK and MoE kernels), which aligns with the extensive changes across multiple MoE-related files and kernel implementations.
  • Description check ✅ Passed: check skipped because CodeRabbit's high-level summary is enabled.



@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 6

🧹 Nitpick comments (3)
diffulex/moe/__init__.py (1)

10-16: Remove the dead fallback or make it a real branch.

Line 15 returns unconditionally, so Line 16 can never execute. If SparseMoEBlock is still meant to be a fallback, this needs an availability/config check instead of an unreachable return.

Suggested cleanup
 def build_mlp_or_moe(config, layer_idx: int, dense_factory):
     """Build a dense MLP or MoE block according to the config."""
     if is_moe_layer(config, layer_idx):
         return FusedSparseMoEBlock.from_config(config)
-        return SparseMoEBlock.from_config(config)
     return dense_factory()
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@diffulex/moe/__init__.py` around lines 10 - 16, The function build_mlp_or_moe
currently returns FusedSparseMoEBlock.from_config(config) unconditionally,
making the subsequent return of SparseMoEBlock.from_config(config) dead code;
either remove the unreachable fallback or implement a real branch that chooses
SparseMoEBlock when FusedSparseMoEBlock is unavailable. Update build_mlp_or_moe
to check availability (e.g., try/except ImportError or a feature flag) before
calling FusedSparseMoEBlock.from_config(config) and only call
SparseMoEBlock.from_config(config) when the fused implementation is not
available, or delete the redundant return line if the fused block is the sole
supported implementation; reference the symbols build_mlp_or_moe,
FusedSparseMoEBlock, SparseMoEBlock, and is_moe_layer when making the change.
diffulex_kernel/__init__.py (1)

42-50: Keep the lazy-export surface consistent.

fused_topk is exposed through both __getattr__ and __all__, but fused_moe is only exposed through __getattr__ while Line 62 keeps it commented out. If fused_moe is public, export it consistently; if it is private, hiding it in one place and exposing it in another is confusing.

Also applies to: 55-63

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@diffulex_kernel/__init__.py` around lines 42 - 50, The lazy-export surface is
inconsistent: __getattr__ exposes fused_topk and fused_moe but __all__ only
lists fused_topk (fused_moe is commented out). Make the exports consistent by
either adding "fused_moe" to the module-level __all__ list (or uncommenting the
existing entry) if it should be public, or remove/deny export in __getattr__ for
fused_moe if it should be private; update the code paths around __getattr__, the
fused_topk and fused_moe import lines, and the __all__ definition so both
symbols are treated the same way.
diffulex/moe/topk.py (1)

9-10: Lazy-import the Triton backend.

TopKRouter(impl="torch") and topk_pytorch_reference() do not need the kernel package, but the module-level from diffulex_kernel import fused_topk makes diffulex.moe.topk depend on that stack at import time anyway. If diffulex_kernel resolves the Triton module eagerly, CPU/reference users will fail before they can ever select the torch path.

♻️ Suggested change
-from diffulex_kernel import fused_topk
-
 def topk_pytorch_reference(
     router_logits: torch.Tensor,
     top_k: int,
@@
         if impl == "torch":
             self.impl = topk_pytorch_reference
         elif impl == "triton":
+            from diffulex_kernel import fused_topk
             self.impl = fused_topk
         else:
             raise ValueError(f"Unsupported impl: {impl!r}")

Also applies to: 59-64

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@diffulex/moe/topk.py` around lines 9 - 10, The module currently imports
fused_topk at top-level causing an eager dependency; change to lazy-import
diffulex_kernel.fused_topk only where needed: move the import into the code
paths that actually call it (e.g., inside TopKRouter implementation branch that
selects the Triton backend and inside the function that calls fused_topk), so
TopKRouter(impl="torch") and topk_pytorch_reference() can be imported without
resolving diffulex_kernel; ensure you reference fused_topk by importing it
locally just before use and keep existing function/class names (TopKRouter,
topk_pytorch_reference, fused_topk) to locate the spots to modify.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 2ad030d3-f366-4b9c-9e61-c3dd60f9c15e

📥 Commits

Reviewing files that changed from the base of the PR and between 9ead055 and 79f111a.

📒 Files selected for processing (15)
  • diffulex/engine/dllm_block.py
  • diffulex/mixin/multi_block/engine/request.py
  • diffulex/model/sdar_moe.py
  • diffulex/moe/__init__.py
  • diffulex/moe/fused_moe.py
  • diffulex/moe/moe_impl.py
  • diffulex/moe/topk.py
  • diffulex/sampler/sdar.py
  • diffulex/utils/loader.py
  • diffulex_kernel/__init__.py
  • diffulex_kernel/python/fused_moe_triton.py
  • diffulex_kernel/python/fused_topk_triton.py
  • pyproject.toml
  • test/python/kernel/test_fused_moe.py
  • test/python/kernel/test_fused_topk.py

Comment on lines +21 to +23
    M,  # number of real tokens (for bounds checking)
    H: tl.constexpr,
    I: tl.constexpr,


⚠️ Potential issue | 🟡 Minor

Rename I here as well.

The same single-letter dimension name is tripping Ruff E741 in the kernel signature and the wrapper local. Expanding it will keep the new module lint-clean and make the pointer math easier to scan.

Also applies to: 122-123, 278-280

🧰 Tools
🪛 Ruff (0.15.7)

[error] 23-23: Ambiguous variable name: I

(E741)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@diffulex_kernel/python/fused_moe_triton.py` around lines 21 - 23, Rename the
single-letter kernel dimension `I` to a descriptive name (e.g., `I_dim` or
`INPUT_SIZE`) both in the Triton kernel signature (the `I: tl.constexpr`
parameter) and in the wrapper/local variables that reference it so Ruff E741 is
resolved; update every usage (pointer arithmetic, index calculations, and any
calls that pass the argument) to the new name, including the other occurrences
noted around the later blocks (the local wrapper variable and the other kernel
signatures/usages), ensuring all references (e.g., kernel definition, launch
invocation, and any local variable named `I`) are consistently renamed.

Comment on lines +266 to +343
def _run_fused_moe_kernels(
    hidden_states: torch.Tensor,
    w13: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    top_k: int,
    hidden_act: str = "silu",
) -> torch.Tensor:
    assert hidden_states.ndim == 2
    assert hidden_act == "silu"

    M, H = hidden_states.shape
    E = w13.shape[0]
    I = w13.shape[1] // 2

    BLOCK_M = 64
    BLOCK_I = 64
    BLOCK_H = 64

    sorted_token_ids, sorted_weights, expert_ids, num_tokens_post_pad = _moe_sorting(
        topk_ids, topk_weights, num_experts=E, block_size=BLOCK_M,
    )

    if num_tokens_post_pad == 0:
        return torch.zeros((M, H), dtype=hidden_states.dtype, device=hidden_states.device)

    num_tiles_m = num_tokens_post_pad // BLOCK_M

    intermediate = torch.zeros(
        (num_tokens_post_pad, I),
        dtype=hidden_states.dtype,
        device=hidden_states.device,
    )

    # Pass 1
    grid_pass1 = (num_tiles_m, triton.cdiv(I, BLOCK_I))
    _fused_moe_kernel_pass1[grid_pass1](
        hidden_states, w13, intermediate,
        sorted_token_ids, expert_ids, num_tokens_post_pad,
        M,
        H=H, I=I,
        stride_hs_m=hidden_states.stride(0), stride_hs_h=hidden_states.stride(1),
        stride_w13_e=w13.stride(0), stride_w13_i=w13.stride(1), stride_w13_h=w13.stride(2),
        stride_inter_t=intermediate.stride(0), stride_inter_i=intermediate.stride(1),
        BLOCK_M=BLOCK_M, BLOCK_I=BLOCK_I, BLOCK_H=BLOCK_H,
    )

    # Pass 2
    output = torch.zeros((M, H), dtype=hidden_states.dtype, device=hidden_states.device)

    grid_pass2 = (num_tiles_m, triton.cdiv(H, BLOCK_H))
    _fused_moe_kernel_pass2[grid_pass2](
        intermediate, w2, output,
        sorted_token_ids, expert_ids, sorted_weights, num_tokens_post_pad,
        M,
        I=I, H=H,
        stride_inter_t=intermediate.stride(0), stride_inter_i=intermediate.stride(1),
        stride_w2_e=w2.stride(0), stride_w2_h=w2.stride(1), stride_w2_i=w2.stride(2),
        stride_out_m=output.stride(0), stride_out_h=output.stride(1),
        BLOCK_M=BLOCK_M, BLOCK_H=BLOCK_H, BLOCK_I=BLOCK_I,
    )

    return output


def fused_moe(
    hidden_states: torch.Tensor,
    w13: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    top_k: int,
    hidden_act: str = "silu",
) -> torch.Tensor:
    return _run_fused_moe_kernels(
        hidden_states, w13, w2, topk_weights, topk_ids, top_k, hidden_act
    )


⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

cat diffulex_kernel/python/fused_moe_triton.py

Repository: SJTU-DENG-Lab/Diffulex

Length of output: 12175


Replace asserts with explicit shape validation before kernel launch.

The function currently has no guards before calling the Triton kernels except two asserts that disappear under python -O. When tensors have incompatible shapes, the kernels fail with opaque CUDA errors rather than clear Python exceptions. Add validation for:

  • w13 and w2 must be 3D tensors (kernels use .stride(0), .stride(1), .stride(2))
  • w13.shape[1] must be divisible by 2 (intermediate size I = w13.shape[1] // 2)
  • w13.shape[2] == H and w2.shape == (E, H, w13.shape[1] // 2) (dimensions must match kernel expectations)
  • topk_ids and topk_weights must be 2D with the same shape (required by _moe_sorting)
  • topk_ids.shape[0] == M (routing tensors must have one row per token in hidden_states)
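A hedged sketch of what that validation could look like (the function name and messages are illustrative, not the repo's code; it checks exactly the bullets above and raises before any kernel launch):

```python
import torch

def validate_fused_moe_inputs(hidden_states, w13, w2, topk_weights, topk_ids):
    if hidden_states.ndim != 2:
        raise ValueError(f"hidden_states must be 2D, got {hidden_states.ndim}D")
    M, H = hidden_states.shape
    if w13.ndim != 3 or w2.ndim != 3:
        raise ValueError("w13 and w2 must be 3D expert-weight tensors")
    if w13.shape[1] % 2 != 0:
        raise ValueError(f"w13.shape[1] must be even, got {w13.shape[1]}")
    E, inter = w13.shape[0], w13.shape[1] // 2  # compute I only after the check
    if w13.shape[2] != H:
        raise ValueError(f"w13.shape[2]={w13.shape[2]} must equal H={H}")
    if tuple(w2.shape) != (E, H, inter):
        raise ValueError(f"w2.shape={tuple(w2.shape)} must be {(E, H, inter)}")
    if topk_ids.ndim != 2 or topk_ids.shape != topk_weights.shape:
        raise ValueError("topk_ids and topk_weights must be 2D with equal shapes")
    if topk_ids.shape[0] != M:
        raise ValueError(f"topk_ids.shape[0]={topk_ids.shape[0]} must equal M={M}")

# Consistent shapes pass silently; mismatched ones fail with a clear message
validate_fused_moe_inputs(
    torch.randn(4, 16), torch.randn(8, 64, 16), torch.randn(8, 16, 32),
    torch.randn(4, 2), torch.zeros(4, 2, dtype=torch.long),
)
```

Unlike asserts, these checks survive python -O and name the offending tensor instead of surfacing as an opaque CUDA error.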
🧰 Tools
🪛 Ruff (0.15.7)

[error] 280-280: Ambiguous variable name: I

(E741)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@diffulex_kernel/python/fused_moe_triton.py` around lines 266 - 343, Replace
the fragile asserts in _run_fused_moe_kernels (and validate in fused_moe caller)
with explicit runtime checks that raise clear exceptions: verify w13 and w2 are
3D tensors, check w13.shape[1] % 2 == 0 and compute I = w13.shape[1] // 2 only
after that, ensure w13.shape[2] == H and w2.shape == (E, H, I), confirm topk_ids
and topk_weights are 2D with identical shapes and that topk_ids.shape[0] == M;
raise ValueError (or TypeError) with descriptive messages naming the offending
tensor (w13, w2, topk_ids, topk_weights) so kernel launches fail fast with clear
Python errors.

Comment on lines +36 to +47
def __getstate__(self):
    s = self.__dict__.copy()
    s['_req'] = s['_req']()
    if '_dllm_block_buffer' in s:
        s['_dllm_block_buffer'] = s['_dllm_block_buffer']()
    return s

def __setstate__(self, state):
    s = self.__dict__ = state.copy()
    s['_req'] = weakref_fn(s['_req'])
    if '_dllm_block_buffer' in s:
        s['_dllm_block_buffer'] = weakref_fn(s['_dllm_block_buffer'])


⚠️ Potential issue | 🟠 Major

These new pickle hooks are shadowed by the older ones below.

Lines 73-80 and 223-229 redefine __getstate__/__setstate__, so Python drops the versions added here. That makes the new weakref rehydration path unreachable and leaves two conflicting serialization implementations in the same file.

Also applies to: 198-205

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@diffulex/engine/dllm_block.py` around lines 36 - 47, The file defines
__getstate__ and __setstate__ multiple times which causes the later/older
definitions to shadow the new weakref-based serialization path; consolidate to a
single pair of methods that implement the weakref handling: keep the
implementations that use weakref_fn and convert s['_req'] and
s['_dllm_block_buffer'] to/from weak references, remove or merge the older
duplicate __getstate__/__setstate__ definitions so only one canonical
implementation remains, and ensure the final __setstate__/__getstate__ reference
the weakref_fn helper and the _req and _dllm_block_buffer attributes
consistently.
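The pattern being consolidated can be demonstrated on a minimal stand-alone class (Owner and Block here are stand-ins for this illustration, not DllmBlock; weakref.ref plays the role of weakref_fn):

```python
import pickle
import weakref

class Owner:
    pass

class Block:
    def __init__(self, req):
        self._req = weakref.ref(req)   # plain weakrefs cannot be pickled

    def __getstate__(self):
        s = self.__dict__.copy()
        s['_req'] = s['_req']()        # pickle the strong object instead
        return s

    def __setstate__(self, state):
        self.__dict__ = state.copy()
        # Re-wrap as a weakref; the owner must stay strongly referenced
        self.__dict__['_req'] = weakref.ref(self.__dict__['_req'])

owner = Owner()
blk = Block(owner)
# Pickling owner and block together lets pickle's memoization keep the
# restored back-reference pointing at the restored owner.
restored_owner, restored_blk = pickle.loads(pickle.dumps((owner, blk)))
assert restored_blk._req() is restored_owner
```

If duplicate __getstate__/__setstate__ definitions remain in the class body, only the last pair takes effect, which is why the review asks for a single canonical implementation.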

Comment on lines +29 to +38
def __init__(
    self,
    hidden_size: int,
    intermediate_size: int,
    num_experts: int,
    top_k: int,
    *,
    hidden_act: str = "silu",
    norm_topk_prob: bool = True,
) -> None:


🛠️ Refactor suggestion | 🟠 Major

Fail fast on unsupported activations.

from_config() forwards arbitrary config.hidden_act, but the backend currently only supports "silu". Right now an unsupported model will instantiate fine and then abort on its first forward. Guard it here or route those configs to the unfused MoE block.

♻️ Suggested change
         self.num_experts = num_experts
         self.top_k = top_k
         self.norm_topk_prob = norm_topk_prob
-        self.hidden_act = hidden_act
+        if hidden_act != "silu":
+            raise ValueError(
+                "FusedSparseMoEBlock only supports hidden_act='silu'"
+            )
+        self.hidden_act = hidden_act

Also applies to: 148-156

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@diffulex/moe/fused_moe.py` around lines 29 - 38, Constructor currently
accepts arbitrary hidden_act causing instantiation of a fused MoE that only
supports "silu"; update the validation to fail fast by checking hidden_act in
__init__ (and the alternate constructor/loader referenced at lines ~148-156,
e.g., from_config or similar factory) and raise a clear ValueError if hidden_act
!= "silu" that points users to use the unfused MoE block instead; ensure the
check is implemented early in the FusedMoE initialization path (reference
symbols: __init__, hidden_act, from_config) so unsupported configs never create
a fused instance.

Comment on lines +12 to +18
def fused_moe_pytorch_reference(
    hidden_states: torch.Tensor,  # (M, H)
    w13: torch.Tensor,            # (E, 2*I, H)
    w2: torch.Tensor,             # (E, H, I)
    topk_weights: torch.Tensor,   # (M, top_k)
    topk_ids: torch.Tensor,       # (M, top_k)
    top_k: int,


⚠️ Potential issue | 🟡 Minor

Rename I across this file.

Ruff is already flagging every I here as E741, so the new test module will stay lint-red until the dimension name is expanded (intermediate_size, intermediate_dim, etc.). The same rename needs to be applied in _run_test and the later test locals.

Also applies to: 58-69

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@test/python/kernel/test_fused_moe.py` around lines 12 - 18, Rename the
ambiguous dimension name `I` to a descriptive name (e.g., `intermediate_size` or
`intermediate_dim`) everywhere in this test module: update the function
signature of fused_moe_pytorch_reference (change comments and type hints for w13
and w2), update usages inside fused_moe_pytorch_reference, update the helper
function `_run_test` and any local variables or test locals that use `I`
(including the later locals around lines 58-69) so all occurrences are
consistently renamed and lint E741 is resolved.

Comment on lines +375 to +395
def test_determinism():
    """Same inputs → same outputs across two runs."""
    device = "cuda"
    M, H, I, E, top_k = 64, 128, 64, 8, 2

    torch.manual_seed(42)
    hidden_states = torch.randn(M, H, device=device) * 0.1
    w13 = torch.randn(E, 2 * I, H, device=device) * 0.1
    w2 = torch.randn(E, H, I, device=device) * 0.1
    router_logits = torch.randn(M, E, device=device)
    scores = torch.softmax(router_logits, dim=-1)
    topk_weights, topk_ids = torch.topk(scores, top_k, dim=-1)
    topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    topk_ids = topk_ids.to(torch.int32)

    out1 = fused_moe(hidden_states, w13, w2, topk_weights, topk_ids, top_k)
    out2 = fused_moe(hidden_states, w13, w2, topk_weights, topk_ids, top_k)

    # Note: atomic_add may have non-deterministic ordering for fp32,
    # but values should still be very close
    torch.testing.assert_close(out1, out2, atol=1e-5, rtol=1e-5)


⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

find . -type f -name "*.py" | grep -E "(fused_moe|test_fused)" | head -20

Repository: SJTU-DENG-Lab/Diffulex

Length of output: 217


🏁 Script executed:

find . -type f -name "*fused_moe*" | head -20

Repository: SJTU-DENG-Lab/Diffulex

Length of output: 177


🏁 Script executed:

cat -n ./diffulex_kernel/python/fused_moe_triton.py | head -150

Repository: SJTU-DENG-Lab/Diffulex

Length of output: 6333


🏁 Script executed:

rg "atomic_add" ./diffulex_kernel/python/fused_moe_triton.py -A 3 -B 3

Repository: SJTU-DENG-Lab/Diffulex

Length of output: 351


🏁 Script executed:

cat -n ./test/python/kernel/test_fused_moe.py | sed -n '370,400p'

Repository: SJTU-DENG-Lab/Diffulex

Length of output: 1621


🏁 Script executed:

cat -n ./diffulex/moe/fused_moe.py | head -100

Repository: SJTU-DENG-Lab/Diffulex

Length of output: 3809


🏁 Script executed:

rg "determinism|atomic" ./test/python/kernel/test_fused_moe.py -B 2 -A 2

Repository: SJTU-DENG-Lab/Diffulex

Length of output: 497


🏁 Script executed:

wc -l ./diffulex_kernel/python/fused_moe_triton.py

Repository: SJTU-DENG-Lab/Diffulex

Length of output: 114


🏁 Script executed:

cat -n ./diffulex_kernel/python/fused_moe_triton.py | sed -n '150,180p'

Repository: SJTU-DENG-Lab/Diffulex

Length of output: 1515


🏁 Script executed:

rg "top_k" ./test/python/kernel/test_fused_moe.py | head -20

Repository: SJTU-DENG-Lab/Diffulex

Length of output: 960


🏁 Script executed:

rg "deterministic|CUDA_LAUNCH_BLOCKING" . --type py -l

Repository: SJTU-DENG-Lab/Diffulex

Length of output: 216


🏁 Script executed:

cat -n ./diffulex_kernel/python/fused_moe_triton.py | sed -n '175,185p'

Repository: SJTU-DENG-Lab/Diffulex

Length of output: 533


Rename or adjust this test—it does not validate determinism.

With top_k=2 and tl.atomic_add(), thread ordering is non-deterministic on GPU, causing rounding variations. The test mitigates this with relaxed tolerances (atol=1e-5, rtol=1e-5), making it a bounded-approximation test, not a determinism test. Either:

  1. Use top_k=1 to eliminate expert conflicts and actual determinism, or
  2. Rename to test_approximation() or similar to reflect what it actually validates.
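A small CPU-only illustration (written for this note, not the repo's code) of why accumulation order matters for fp32, which is the root of the tolerance question above:

```python
import torch

torch.manual_seed(0)
vals = torch.randn(100_000, dtype=torch.float32)

# Same values, two accumulation orders; floating-point addition is not
# associative, so the results may differ in the last bits, just as
# differently-ordered atomic_add contributions can on a GPU.
fwd = float(vals.sum())
rev = float(vals.flip(0).sum())

# The two orders agree only up to rounding error, which is why a
# tolerance-based comparison is an approximation check, not a
# determinism check.
assert abs(fwd - rev) <= 1e-4 * float(vals.abs().sum())
```

With top_k=1 each output row is written by exactly one expert, so no atomic accumulation conflicts occur and bitwise determinism can actually be asserted.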
🧰 Tools
🪛 Ruff (0.15.7)

[error] 378-378: Ambiguous variable name: I

(E741)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@test/python/kernel/test_fused_moe.py` around lines 375 - 395, The test named
test_determinism does not actually guarantee determinism because top_k=2 allows
atomic_add race-induced FP32 variations; update the test to either set top_k=1
(change the local top_k variable to 1 so fused_moe runs without expert conflicts
and true determinism is validated) or rename the test (e.g., test_approximation)
and its docstring to reflect it verifies bounded numerical closeness for
fused_moe with top_k=2; also update the test name/docstring and any inline
comment accordingly so readers and CI expectations match the chosen behavior.
