49 changes: 29 additions & 20 deletions tpu_inference/kernels/fused_moe/v1/kernel.py
@@ -7,6 +7,7 @@
from jax import lax
from jax._src import dtypes
from jax.experimental import pallas as pl
from jax.experimental import shard_map
from jax.experimental.pallas import tpu as pltpu

P = jax.sharding.PartitionSpec
@@ -144,7 +145,7 @@ def _fused_ep_moe_kernel(
a2a_acc_sem,
*,
top_k: int,
ep_name: str,
ep_axis_name: str,
# Kernel tuning params.
bt: int, # Block size of local_num_tokens.
bf: int, # Block size of intermediate_size.
@@ -155,8 +156,8 @@
bd1c: int, # Compute size of block hidden_size.
bd2c: int, # Compute size of block hidden_size.
):
my_id = lax.axis_index(ep_name)
num_devices = lax.axis_size(ep_name)
my_id = lax.axis_index(ep_axis_name)
num_devices = lax.axis_size(ep_axis_name)
local_num_tokens = tokens_hbm.shape[0]
local_num_experts, intermediate_size, hidden_size = w2_hbm.shape
# num_experts = local_num_experts * num_devices
@@ -186,13 +187,13 @@ def sync_barrier():
barrier_sem = pltpu.get_barrier_semaphore()
pltpu.semaphore_signal(
barrier_sem,
device_id=right_id,
device_id_type=pltpu.DeviceIdType.LOGICAL,
device_id=(0, right_id),
device_id_type=pltpu.DeviceIdType.MESH,
)
pltpu.semaphore_wait(barrier_sem, 1)

def start_fetch_b_gating(bt_id, priority=0):
is_valid = jnp.logical_and(0 <= bt_id, bt_id < num_bt)
is_valid = jnp.logical_and(bt_id >= 0, bt_id < num_bt)
sz = pl.multiple_of(lax.select(is_valid, bt, 0), bt)
bt_sem_id = (bt_id + 2) % 2
b_gating_sem = local_sems.at[bt_sem_id, 0]
@@ -276,7 +277,7 @@ def _all_reduce_metadata(
dst_ref=d2e_count_vmem.at[row_id],
send_sem=send_sem,
recv_sem=recv_sem,
device_id=(right_id, ),
device_id=(0, right_id),
device_id_type=pltpu.DeviceIdType.MESH,
).wait()
row_id = (row_id + num_devices - 1) % num_devices
@@ -358,7 +359,10 @@ def start_a2a_scatter(bt_id, e_sem_id, local_e_id):
pl.ds(start, remote_sz)],
send_sem=send_sems.at[e_sem_id],
recv_sem=recv_sems.at[e_sem_id],
device_id=(recv_id, ),
device_id=(
0,
recv_id,
),
).start()
a2a_s_sends_x2_smem[e_sem_id] = send_sz

@@ -402,7 +406,7 @@ def start_a2a_gather(bt_id, e_sem_id, local_e_id):
dst_ref=a2a_g_hbm.at[my_e_id, pl.ds(0, remote_sz)],
send_sem=send_sems.at[e_sem_id],
recv_sem=a2a_gather_sem,
device_id=(recv_id, ),
device_id=(0, recv_id),
).start()
start += sz

@@ -412,7 +416,7 @@ def wait_a2a_gather_send(bt_id, e_sem_id, local_e_id):
sz = expert_sizes_x2_smem[bt_sem_id, 0, my_e_id]
local_sz = d2e_count_x2_smem[bt_sem_id, my_id, 0, my_e_id]
remote_sz = sz - local_sz
is_valid = jnp.logical_and(0 <= local_e_id, local_e_id < local_num_experts)
is_valid = jnp.logical_and(local_e_id >= 0, local_e_id < local_num_experts)
remote_sz = lax.select(is_valid, remote_sz, 0)
pltpu.make_async_copy(
@@ -731,7 +735,7 @@ def start_send_bo(bt_id, priority=0):
).start(priority=priority)

def wait_send_bo(bt_id):
is_valid = jnp.logical_and(0 <= bt_id, bt_id < num_bt)
is_valid = jnp.logical_and(bt_id >= 0, bt_id < num_bt)
sz = pl.multiple_of(lax.select(is_valid, bt, 0), bt)
bt_sem_id = (bt_id + 2) % 2
b_output_sem = local_sems.at[bt_sem_id, 4]
@@ -831,6 +835,7 @@ def _():
"bfc",
"bd1c",
"bd2c",
"ep_axis_name",
],
)
def fused_ep_moe(
@@ -850,12 +855,14 @@ def fused_ep_moe(
bfc: int,
bd1c: int,
bd2c: int,
ep_axis_name: str = 'model',
):
if len(mesh.axis_names) != 1:
raise ValueError("Mesh must have only one axis")
# Assert all other axes have length of 1
assert len(mesh.shape) == 2, "Expect 2D mesh in tpu-inference"
assert 'data' in mesh.shape and mesh.shape['data'] == 1, \
"Expect data axis size of 1 in tpu-inference"

ep_name = mesh.axis_names[0]
ep_size = mesh.axis_sizes[0]
ep_size = mesh.shape[ep_axis_name]
num_devices = ep_size

num_tokens, actual_hidden_size = tokens.shape
@@ -907,7 +914,7 @@ def fused_ep_moe(
functools.partial(
_fused_ep_moe_kernel,
top_k=top_k,
ep_name=ep_name,
ep_axis_name=ep_axis_name,
bt=bt,
bf=bf,
bd1=bd1,
@@ -999,11 +1006,13 @@ def fused_ep_moe(
))

@jax.jit
@jax.shard_map(
@functools.partial(
shard_map.shard_map,
mesh=mesh,
in_specs=(P(ep_name), P(ep_name), P(ep_name), P(ep_name), P()),
out_specs=P(ep_name),
check_vma=False,
in_specs=(P(ep_axis_name), P(ep_axis_name), P(ep_axis_name),
P(ep_axis_name), P()),
out_specs=P(ep_axis_name),
check_rep=False,
)
def kernel(tokens, w1, w2, gating_output, a2a_g_hbm_scratch):
return fused_moe(
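For orientation, a minimal usage sketch of the updated entry point (not part of the PR; the mesh construction, shapes, and dtypes are illustrative assumptions, and the block sizes mirror the ones GptOssMoE picks below):

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh

from tpu_inference.kernels.fused_moe.v1.kernel import fused_ep_moe

# tpu-inference now expects a 2D mesh with a size-1 'data' axis; expert
# parallelism runs along the second axis (named 'model' here).
devices = np.array(jax.devices()).reshape(1, -1)
mesh = Mesh(devices, ("data", "model"))

# Illustrative sizes: T tokens, hidden D, intermediate F, E experts.
T, D, F, E, TOP_K = 128, 2048, 1024, 32, 4
tokens = jnp.zeros((T, D), dtype=jnp.bfloat16)
w1 = jnp.zeros((E, 2, D, F), dtype=jnp.bfloat16)  # the two MLP-1 projections stacked on axis 1
w2 = jnp.zeros((E, F, D), dtype=jnp.bfloat16)
gating_output = jnp.zeros((T, E), dtype=jnp.float32)

out_TD = fused_ep_moe(
    mesh=mesh,
    tokens=tokens,
    w1=w1,
    w2=w2,
    gating_output=gating_output,
    top_k=TOP_K,
    bt=32, bf=512, bd1=512, bd2=512,      # block sizes
    btc=32, bfc=256, bd1c=256, bd2c=256,  # compute block sizes
    ep_axis_name="model",  # new keyword introduced by this change
)
```

The switch to `device_id=(0, right_id)` with `DeviceIdType.MESH` inside the kernel follows from the same move: on a 2D mesh, peers are addressed by mesh coordinates, i.e. index 0 on the size-1 'data' axis followed by the position along `ep_axis_name`.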
69 changes: 36 additions & 33 deletions tpu_inference/layers/jax/moe/gpt_oss_moe.py
@@ -4,8 +4,10 @@
import jax.numpy as jnp
from flax import nnx
from flax.typing import Sharding
from jax.sharding import Mesh
from jaxtyping import Float

from tpu_inference.kernels.fused_moe.v1.kernel import fused_ep_moe
from tpu_inference.layers.jax.base import create_param
from tpu_inference.layers.jax.layers import FlaxUtils
from tpu_inference.layers.jax.moe.moe import Router
@@ -46,13 +48,7 @@ def __call__(self, x_TD: Float):

router_logits_TE += self.bias_E.value

weights_TX, selected_experts_TX = jax.lax.top_k(
router_logits_TE, self.num_experts_per_tok)

normalized_weights_TX = jax.nn.softmax(weights_TX.astype(self.dtype),
axis=-1)

return normalized_weights_TX, selected_experts_TX
return router_logits_TE


def _swiglu(x: Float, alpha: Float, limit: Float) -> Float:
@@ -90,37 +86,44 @@ class GptOssMoE(nnx.Module):

random_init: bool = False

mesh: Mesh

def __call__(self, x_TD: Float) -> Float:
"""Performs the forward pass for the GPT-OSS MoE layer."""
x_TD = jnp.asarray(x_TD, self.dtype)
x_TD = nnx.with_sharding_constraint(x_TD, self.activation_ffw_td)

weights_TX, indices_TX = self.router(x_TD)

# First MLP layer (up-projection)
with jax.named_scope("MLP #1"):
up_proj_TEF2 = jnp.einsum('TD,EDF -> TEF', x_TD,
self.mlp1_weight_EDF2.value)
up_proj_TEF2 += self.mlp1_bias_EF2.value

fuse_TEF = _swiglu(up_proj_TEF2,
alpha=self.swiglu_alpha,
limit=self.swiglu_limit)

# Second MLP layer (down-projection)
with jax.named_scope("MLP #2"):
down_proj_TED = jnp.einsum('TEF,EFD -> TED', fuse_TEF,
self.mlp2_weight_EFD.value)
down_proj_TED += self.mlp2_bias_ED.value

# Weighted sum of expert outputs
with jax.named_scope("sum"):
indices_for_gather = indices_TX[..., None]
gathered_down_proj_TED = jnp.take_along_axis(down_proj_TED,
indices_for_gather,
axis=1)
output_TD = jnp.einsum('TXD,TX -> TD', gathered_down_proj_TED,
weights_TX)
router_logits_TE = self.router(x_TD)

block_size = {
"bt": 32,
"bf": 512,
"bd1": 512,
"bd2": 512,
"btc": 32,
"bfc": 256,
"bd1c": 256,
"bd2c": 256,
}
ep_axis_name = self.efd_sharding[0]
# TODO: Currently, we must reshape the tensors to fit the MoE kernel's
# required shape. We will eliminate this step and load the tensors in
# their desired final shape once the weight loading process (with fp4
# support) is finalized.
mlp1_weight_E2DF = jnp.swapaxes(
jnp.reshape(self.mlp1_weight_EDF2.value,
(self.num_local_experts, self.hidden_size, 2,
self.intermediate_size_moe)), 1, 2)
output_TD = fused_ep_moe(
mesh=self.mesh,
tokens=x_TD,
w1=mlp1_weight_E2DF,
w2=self.mlp2_weight_EFD.value,
gating_output=router_logits_TE,
top_k=self.router.num_experts_per_tok,
ep_axis_name=ep_axis_name,
**block_size,
)

return output_TD.astype(self.dtype)

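The TODO above concerns the layout change from the checkpoint's (E, D, 2F) weight to the kernel's (E, 2, D, F) weight. A shape-only sketch of that reshape + swapaxes, with made-up sizes and assuming (as the reshape implies) that the two projections occupy the two contiguous halves of the trailing axis:

```python
import jax.numpy as jnp

# Illustrative sizes only; not taken from the model config.
E, D, F = 4, 8, 6
mlp1_weight_EDF2 = jnp.arange(E * D * 2 * F, dtype=jnp.float32).reshape(E, D, 2 * F)

# Split the trailing 2F axis into (2, F), then move the 2-way split ahead of D,
# mirroring the reshape + swapaxes in GptOssMoE.__call__.
mlp1_weight_E2DF = jnp.swapaxes(mlp1_weight_EDF2.reshape(E, D, 2, F), 1, 2)

assert mlp1_weight_E2DF.shape == (E, 2, D, F)
# Slice 0 holds the first F columns of the original trailing axis, slice 1 the rest.
assert jnp.array_equal(mlp1_weight_E2DF[:, 0], mlp1_weight_EDF2[..., :F])
assert jnp.array_equal(mlp1_weight_E2DF[:, 1], mlp1_weight_EDF2[..., F:])
```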
19 changes: 12 additions & 7 deletions tpu_inference/models/jax/gpt_oss.py
@@ -1,6 +1,5 @@
import re
from dataclasses import dataclass
from typing import List, Optional, Tuple

import jax
import jax.numpy as jnp
@@ -9,17 +8,22 @@
from flax.typing import PRNGKey
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P
from vllm.config import VllmConfig

from tpu_inference.layers.jax.attention.gpt_oss_attention import (
AttentionMetadata, GptOssAttention)
AttentionMetadata,
GptOssAttention,
)
from tpu_inference.layers.jax.constants import KVCacheType
from tpu_inference.layers.jax.layers import Embedder, LMhead, RMSNorm
from tpu_inference.layers.jax.moe.gpt_oss_moe import GptOssMoE, GptOssRouter
from tpu_inference.layers.jax.transformer_block import TransformerBlock
from tpu_inference.logger import init_logger
from tpu_inference.models.jax.utils.weight_utils import (
get_param, model_weights_generator, print_param_info)
get_param,
model_weights_generator,
print_param_info,
)
from vllm.config import VllmConfig

logger = init_logger(__name__)

@@ -136,6 +140,7 @@ def __init__(self,
edf_sharding=('model', None, None),
efd_sharding=('model', None, None),
ed_sharding=('model', None),
mesh=self.mesh,
)

block = TransformerBlock(
@@ -180,7 +185,7 @@ def __init__(self,
def apply(self, variables, *args, **kwargs):
return self.__call__(*args, **kwargs)

def load_weights(self, rng: PRNGKey, cache_dir: Optional[str] = None):
def load_weights(self, rng: PRNGKey, cache_dir: str | None = None):
"""Loads and transforms all weights from a checkpoint"""
self.rng = nnx.Rngs(rng)

@@ -328,11 +333,11 @@ def get_slice(index):

def __call__(
self,
kv_caches: List[jax.Array],
kv_caches: list[jax.Array],
input_ids: jax.Array,
attention_metadata: AttentionMetadata,
*args,
) -> Tuple[List[KVCacheType], jax.Array, List[jax.Array]]:
) -> tuple[list[KVCacheType], jax.Array, list[jax.Array]]:
is_prefill = False
x = self.embedder.encode(input_ids)
