2,513 changes: 2,513 additions & 0 deletions oss_log_bias.txt

Large diffs are not rendered by default.

1,731 changes: 1,731 additions & 0 deletions repro_attention_dp.py

Large diffs are not rendered by default.

14 changes: 10 additions & 4 deletions tpu_inference/layers/common/sharding.py
@@ -121,13 +121,19 @@ def from_vllm_config(cls,
if enable_dp_attention:
# Replicate attention layer when num_kv_heads < TP
num_kv_heads = vllm_config.model_config.get_total_num_kv_heads()

kv_dtype = utils.get_jax_dtype_from_str_dtype(
vllm_config.cache_config.cache_dtype) or jnp.bfloat16
packing = 4 // jnp.dtype(kv_dtype).itemsize
# When num_kv_heads * 2 / packing < TP, tensor parallelism would
# duplicate KV heads across devices, wasting kv cache memory.
# Use attention DP instead to reduce per-device num_kv_heads and
# eliminate this waste.

# if head_dim is 64, multiply packing by 2
if vllm_config.model_config.get_head_size() == 64:
packing *= 2

num_kv_heads_per_device_in_kv_cache = (num_kv_heads * 2) / packing
attn_dp = max(
int(tensor_parallelism // num_kv_heads_per_device_in_kv_cache),
@@ -166,10 +172,10 @@ def validate(cls, vllm_config, sharding_strategy):
f"LoRA is not supported with data parallelism "
f"(DP size: {total_dp_size}). Please disable LoRA or "
f"set data parallelism to 1.")
if not os.environ.get("NEW_MODEL_DESIGN", False):
raise ValueError(
"Must run DP with NEW_MODEL_DESIGN enabled. Please set the "
"NEW_MODEL_DESIGN=True.")
# if not os.environ.get("NEW_MODEL_DESIGN", False):
# raise ValueError(
# "Must run DP with NEW_MODEL_DESIGN enabled. Please set the "
# "NEW_MODEL_DESIGN=True.")

@property
def total_dp_size(self) -> int:
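The hunk above derives the attention-DP degree from how many KV heads the paged KV cache can hold per device. Below is a minimal, self-contained sketch of that sizing rule; the helper name attn_dp_size, the example numbers, and the lower bound of 1 (the second argument to max is cut off in the hunk) are assumptions, not part of the module.

import jax.numpy as jnp

def attn_dp_size(num_kv_heads, tensor_parallelism,
                 kv_dtype=jnp.bfloat16, head_size=128):
    # Packing: how many KV values fit in one 32-bit word for this dtype.
    packing = 4 // jnp.dtype(kv_dtype).itemsize
    if head_size == 64:
        # The hunk above doubles the packing factor when head_dim is 64.
        packing *= 2
    # K and V together contribute the factor of 2. When this is below TP,
    # tensor parallelism would replicate KV heads and waste cache memory.
    num_kv_heads_per_device_in_kv_cache = (num_kv_heads * 2) / packing
    return max(int(tensor_parallelism // num_kv_heads_per_device_in_kv_cache), 1)

# Example: 8 KV heads, bfloat16 cache, head_size 128, TP = 16
#   packing = 2, cache heads per device = 8 * 2 / 2 = 8, attn_dp = 16 // 8 = 2
assert attn_dp_size(8, 16) == 2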
135 changes: 90 additions & 45 deletions tpu_inference/layers/vllm/fused_moe.py
@@ -8,7 +8,7 @@

from tpu_inference.layers.vllm.linear_common import \
slice_sharded_tensor_for_concatenation

from tpu_inference.layers.common.sharding import ShardingAxisName
P = PartitionSpec


@@ -110,7 +110,7 @@ def tensor_sharded_gmm_merged_column_parallel(
# adapted from https://github.com/pytorch/xla/blob/1d409399474197c484894be90b75d9855393dda5/torch_xla/experimental/custom_kernel.py#L1401
m, k, g = lhs.shape[0], lhs.shape[1], rhs.shape[0]
n = rhs.shape[1] if transpose_rhs else rhs.shape[2]
tm, tk, tn = _get_tiling_size_for_gmm_kernel(m, k, n, g)
tm, tk, tn = _get_tiling_size_for_gmm_kernel(m//mesh.shape["data"], k, n, g)

_gmm = functools.partial(
gmm,
@@ -123,16 +123,25 @@ gmm_result = shard_map(
gmm_result = shard_map(
_gmm,
mesh=mesh,
in_specs=(P(), P(None, "model", None), P()),
out_specs=(P(None, "model")),
in_specs=(P(ShardingAxisName.MLP_DATA, None), P(None, ShardingAxisName.MLP_TENSOR, None), P(ShardingAxisName.MLP_DATA)),
out_specs=(P(ShardingAxisName.MLP_DATA, ShardingAxisName.MLP_TENSOR)),
check_rep=False,
)(lhs, rhs, group_sizes)

if rhs_bias is not None:
rhs_bis = jnp.repeat(rhs_bias, group_sizes, 0, total_repeat_length=m)
gmm_result = (gmm_result + rhs_bis).astype(gmm_result.dtype)
def _add_bias(gmm_result_local, rhs_bias_local, group_sizes_global):
rhs_bis = jnp.repeat(rhs_bias_local, group_sizes_global, 0, total_repeat_length=m//mesh.shape["data"])
return (gmm_result_local + rhs_bis).astype(gmm_result_local.dtype)

gmm_result = shard_map(
_add_bias,
mesh=mesh,
in_specs=(P(ShardingAxisName.MLP_DATA, ShardingAxisName.MLP_TENSOR), P(None, ShardingAxisName.MLP_TENSOR), P(ShardingAxisName.MLP_DATA)),
out_specs=(P(ShardingAxisName.MLP_DATA, ShardingAxisName.MLP_TENSOR)),
check_rep=False,
)(gmm_result, rhs_bias, group_sizes)

n_shards = mesh.shape["model"]
n_shards = mesh.shape['model'] * mesh.shape.get('attn_dp', 1)
output_sizes = [intermediate_size, intermediate_size]

return slice_sharded_tensor_for_concatenation(gmm_result, output_sizes,
@@ -150,7 +159,7 @@ def tensor_sharded_gmm_row_parallel(
# adapted from https://github.com/pytorch/xla/blob/1d409399474197c484894be90b75d9855393dda5/torch_xla/experimental/custom_kernel.py#L1401
m, k, g = lhs.shape[0], lhs.shape[1], rhs.shape[0]
n = rhs.shape[1] if transpose_rhs else rhs.shape[2]
tm, tk, tn = _get_tiling_size_for_gmm_kernel(m, k, n, g)
tm, tk, tn = _get_tiling_size_for_gmm_kernel(m//mesh.shape["data"], k, n, g)

_gmm = functools.partial(
gmm,
@@ -162,19 +171,29 @@

def _gmm_all_reduce(lhs, rhs, group_sizes):
r = _gmm(lhs, rhs, group_sizes)
return jax.lax.psum(r, axis_name="model")
return jax.lax.psum(r, axis_name=ShardingAxisName.MLP_TENSOR)

gmm_result = shard_map(
_gmm_all_reduce,
mesh=mesh,
in_specs=(P(None, "model"), P(None, None, "model"), P()),
out_specs=(P()),
check_rep=False,
in_specs=(P(ShardingAxisName.MLP_DATA, ShardingAxisName.MLP_TENSOR),
P(None, None, ShardingAxisName.MLP_TENSOR), P(ShardingAxisName.MLP_DATA)),
out_specs=(P(ShardingAxisName.MLP_DATA)),
check_rep=False,
)(lhs, rhs, group_sizes)

# jax.debug.print("gmm_result before bias {} {}", gmm_result.sum(), gmm_result.ravel()[:10])
if rhs_bias is not None:
rhs_bias = jnp.repeat(rhs_bias, group_sizes, 0, total_repeat_length=m)
gmm_result = (gmm_result + rhs_bias).astype(gmm_result.dtype)
def _add_bias(gmm_result_local, rhs_bias_local, group_sizes_global):
rhs_bis = jnp.repeat(rhs_bias_local, group_sizes_global, 0, total_repeat_length=m//mesh.shape["data"])
return (gmm_result_local + rhs_bis).astype(gmm_result_local.dtype)

gmm_result = shard_map(
_add_bias,
mesh=mesh,
in_specs=(P(ShardingAxisName.MLP_DATA), P(), P(ShardingAxisName.MLP_DATA)),
out_specs=(P(ShardingAxisName.MLP_DATA)),
check_rep=False,
)(gmm_result, rhs_bias, group_sizes)

return gmm_result
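
Both GMM wrappers above now add the expert bias inside a shard_map, so each data shard repeats the per-expert bias rows according to the group sizes and adds them to its local GMM output (the row-parallel variant additionally all-reduces its partial products with jax.lax.psum over the tensor axis before this step). Here is a toy, single-device sketch of that bias-add step; the shapes, names, and one-device mesh are illustrative only.

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(np.array(jax.devices()[:1]), ("data",))  # 1-device stand-in
m, n, g = 8, 4, 2                        # routed rows, hidden, experts
gmm_out = jnp.ones((m, n))
bias = jnp.arange(g * n, dtype=jnp.float32).reshape(g, n)
group_sizes = jnp.array([5, 3])          # rows assigned to each expert

def _add_bias(out_local, bias_local, group_sizes_local):
    # Expand the (g, n) bias to one row per routed token, then add.
    rep = jnp.repeat(bias_local, group_sizes_local, axis=0,
                     total_repeat_length=out_local.shape[0])
    return (out_local + rep).astype(out_local.dtype)

biased = shard_map(_add_bias, mesh=mesh,
                   in_specs=(P("data"), P(), P("data")),
                   out_specs=P("data"), check_rep=False)(
                       gmm_out, bias, group_sizes)
print(biased.shape)  # (8, 4)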

@@ -191,13 +210,12 @@ def expert_sharded_gmm(
# adapted from https://github.com/pytorch/xla/blob/1d409399474197c484894be90b75d9855393dda5/torch_xla/experimental/custom_kernel.py#L1401
m, k, g = lhs.shape[0], lhs.shape[1], rhs.shape[0]
n = rhs.shape[1] if transpose_rhs else rhs.shape[2]
tm, tk, tn = _get_tiling_size_for_gmm_kernel(m, k, n, g)
tm, tk, tn = _get_tiling_size_for_gmm_kernel(m//mesh.shape["data"], k, n, g)

num_experts_per_shard = num_experts // ep_size
group_offset = jnp.arange(0, num_experts, num_experts_per_shard)
group_offset = jax.lax.with_sharding_constraint(
group_offset, NamedSharding(mesh, P("model")))

group_offset, NamedSharding(mesh, P(ShardingAxisName.EXPERT)))
def _gmm(lhs, rhs, group_sizes, group_offset):
# Group offset for this shard. `group_offset` is sharded, and in this
# sharded function, it has only 1 element and `group_offset.shape` is
@@ -236,8 +254,9 @@ def _gmm(lhs, rhs, group_sizes, group_offset):
gmm_res = shard_map(
_gmm,
mesh=mesh,
in_specs=(P(), P("model", None, None), P(), P("model")),
out_specs=(P("model", None)),
in_specs=(P(ShardingAxisName.MLP_DATA, None), P(ShardingAxisName.EXPERT, None,
None), P(ShardingAxisName.MLP_DATA), P(ShardingAxisName.EXPERT)),
out_specs=(P(ShardingAxisName.EXPERT, None)),
check_rep=False,
)(lhs, rhs, group_sizes, group_offset)
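
In the expert-parallel path, each of the ep_size shards owns a contiguous block of num_experts // ep_size experts, and the sharded group_offset hands every shard the index of its first expert. A small worked example with assumed sizes (16 experts over 4 shards); the slicing below only illustrates which group sizes belong to which shard, while the real _gmm passes group_offset straight to the grouped-matmul kernel.

import jax.numpy as jnp

num_experts, ep_size = 16, 4             # assumed sizes
num_experts_per_shard = num_experts // ep_size
group_offset = jnp.arange(0, num_experts, num_experts_per_shard)
print(group_offset)                      # [ 0  4  8 12], one entry per shard

group_sizes = jnp.array([3, 1, 4, 0, 2, 2, 0, 4, 1, 1, 1, 1, 0, 0, 8, 4])
for shard in range(ep_size):
    start = int(group_offset[shard])
    # Experts (and their token counts) handled by this shard.
    print(shard, group_sizes[start:start + num_experts_per_shard])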

@@ -256,12 +275,11 @@ def _gmm(lhs, rhs, group_sizes, group_offset):
recv_sizes = send_sizes

input_offsets = jax.lax.with_sharding_constraint(
input_offsets, NamedSharding(mesh, P("model")))
input_offsets, NamedSharding(mesh, P(ShardingAxisName.EXPERT)))
send_sizes = jax.lax.with_sharding_constraint(
send_sizes, NamedSharding(mesh, P("model")))
send_sizes, NamedSharding(mesh, P(ShardingAxisName.EXPERT)))
output_offsets = jax.lax.with_sharding_constraint(
output_offsets, NamedSharding(mesh, P("model")))

output_offsets, NamedSharding(mesh, P(ShardingAxisName.EXPERT)))
def _ragged_all_to_all(operand, input_offsets, send_sizes, output_offsets,
recv_sizes):
output = jnp.zeros_like(operand)
@@ -292,8 +310,7 @@ def _ragged_all_to_all(operand, input_offsets, send_sizes, output_offsets,
send_sizes_of_shard,
output_offsets_of_shard,
recv_sizes_of_shard,
axis_name="model")

axis_name=ShardingAxisName.EXPERT)
# Use ragged_all_to_all to send the result from gmm for each expert to all
# the shards. In the working example, the result would be:
# A, A, A, A A, A, A, A A, A, A, A A, A, A, A
@@ -314,7 +331,8 @@ def _ragged_all_to_all(operand, input_offsets, send_sizes, output_offsets,
return shard_map(
_ragged_all_to_all,
mesh=mesh,
in_specs=(P("model", None), P("model"), P("model"), P("model"), P()),
in_specs=(P(ShardingAxisName.EXPERT, None), P(ShardingAxisName.EXPERT),
P(ShardingAxisName.EXPERT), P(ShardingAxisName.EXPERT), P()),
out_specs=(P()),
check_rep=False,
)(gmm_res, input_offsets, send_sizes, output_offsets, recv_sizes)
@@ -350,13 +368,12 @@ def fused_moe_func(
hidden_size = hidden_states.shape[-1]
num_tokens = hidden_states.size // hidden_size
assert global_num_experts == w1.shape[0]
ep_size = mesh.shape["model"] # only used if use_ep is True.
ep_size = mesh.shape['model'] * mesh.shape.get("attn_dp", 1) # only used if use_ep is True.
intermediate_size = w2.shape[-1]
dtype = hidden_states.dtype
assert (num_tokens * topk) % 16 == 0, (
"The kernel requires num_tokens * topk to be a multiple of "
f"16 but got {num_tokens}*{topk}={num_tokens*topk}")

hidden_states = hidden_states.reshape(num_tokens, hidden_size)
gating_output = gating_output.reshape(num_tokens, global_num_experts)

@@ -366,15 +383,28 @@
topk_weights = topk_weights / topk_weights.sum(axis=-1, keepdims=True)
topk_weights = topk_weights.astype(dtype)

topk_indices_flat = topk_indices.flatten()
topk_argsort_indices = jnp.argsort(topk_indices_flat)
topk_argsort_revert_indices = jnp.argsort(topk_argsort_indices)
token_indices = jnp.arange(num_tokens, dtype=jnp.int32).repeat(topk)
token_indices_sorted = token_indices[topk_argsort_indices]
group_sizes = jnp.bincount(topk_indices_flat, length=global_num_experts)

x = hidden_states[token_indices_sorted]

def _process_tokens_locally(hidden_states_local, topk_indices_local):
num_tokens_local = hidden_states_local.shape[0]
topk_indices_flat = topk_indices_local.flatten()
topk_argsort_indices = jnp.argsort(topk_indices_flat)
topk_argsort_revert_indices = jnp.argsort(topk_argsort_indices)
token_indices = jnp.arange(num_tokens_local, dtype=jnp.int32).repeat(topk)
token_indices_sorted = token_indices[topk_argsort_indices]
group_sizes_local = jnp.bincount(topk_indices_flat, length=global_num_experts)

x = hidden_states_local[token_indices_sorted]
return x, group_sizes_local, topk_argsort_revert_indices

x, group_sizes, topk_argsort_revert_indices = shard_map(
_process_tokens_locally,
mesh=mesh,
in_specs=(P(ShardingAxisName.ATTN_DATA, None), P(ShardingAxisName.ATTN_DATA, None)),
out_specs=(P(ShardingAxisName.ATTN_DATA, None), P(ShardingAxisName.ATTN_DATA), P(ShardingAxisName.ATTN_DATA)),
check_rep=False,
)(hidden_states, topk_indices)

# jax.debug.print("hidden_state before gmm {} {}", x.sum(), x.ravel()[:10])
# jax.debug.print("group_sizes {} {}", group_sizes.sum(), group_sizes)
if use_ep:
x = expert_sharded_gmm(
x,
@@ -396,9 +426,11 @@ def fused_moe_func(
mesh=mesh,
intermediate_size=intermediate_size,
)
# jax.debug.print("hidden_state after first gmm x1 {} {}", x1.sum(), x1.ravel()[:10])
# jax.debug.print("hidden_state after first gmm x2 {} {}", x2.sum(), x2.ravel()[:10])

x = activation_fn(activation, x1, x2)

# jax.debug.print("hidden_state after activation {} {}", x.sum(), x.ravel()[:10])
if use_ep:
x = expert_sharded_gmm(
x,
@@ -411,7 +443,7 @@ def fused_moe_func(
)
else:
x = jax.lax.with_sharding_constraint(
x, NamedSharding(mesh, P(None, "model")))
x, NamedSharding(mesh, P(ShardingAxisName.MLP_DATA, "model")))
x = tensor_sharded_gmm_row_parallel(
x,
w2,
Expand All @@ -420,14 +452,27 @@ def fused_moe_func(
transpose_rhs=True,
mesh=mesh,
)

x = x[topk_argsort_revert_indices].reshape(-1, topk, hidden_size)
x = x * jnp.expand_dims(topk_weights, axis=-1)
x = x.sum(axis=-2)
# jax.debug.print("hidden_state after second gmm {} {}", x.sum(), x.ravel()[:10])

def _finalize_output(x_local, topk_argsort_revert_indices_local, topk_weights_local):
x_local = x_local[topk_argsort_revert_indices_local].reshape(-1, topk, hidden_size)
x_local = x_local * jnp.expand_dims(topk_weights_local, axis=-1)
x_local = x_local.sum(axis=-2)
return x_local

x = shard_map(
_finalize_output,
mesh=mesh,
in_specs=(P(ShardingAxisName.ATTN_DATA, None), P(ShardingAxisName.ATTN_DATA), P(ShardingAxisName.ATTN_DATA, None)),
out_specs=(P(ShardingAxisName.ATTN_DATA, None)),
check_rep=False,
)(x, topk_argsort_revert_indices, topk_weights)
# jax.debug.print("hidden_state after finalize output {} {}", x.sum(), x.ravel()[:10])
x = x.reshape(orig_shape)

if reduce_results:
x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P()))
x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P(ShardingAxisName.ATTN_DATA)))
# jax.debug.print("hidden_state after reducing result {} {}", x.sum(), x.ravel()[:10])
return x


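The reworked fused_moe_func keeps token routing local to each attention-DP shard: tokens are duplicated topk times, sorted by expert id so each expert's rows are contiguous for the grouped matmul, then restored with the inverse permutation and combined with the router weights. A single-host toy of that sort/revert pattern (no shard_map; sizes chosen for readability, whereas the real kernel also requires num_tokens * topk to be a multiple of 16):

import jax.numpy as jnp

num_tokens, topk, num_experts, hidden = 4, 2, 3, 8
hidden_states = jnp.arange(num_tokens * hidden,
                           dtype=jnp.float32).reshape(num_tokens, hidden)
topk_indices = jnp.array([[0, 2], [1, 2], [0, 1], [2, 2]])
topk_weights = jnp.full((num_tokens, topk), 0.5)

topk_indices_flat = topk_indices.flatten()
sort_idx = jnp.argsort(topk_indices_flat)            # group rows by expert
revert_idx = jnp.argsort(sort_idx)                   # inverse permutation
token_indices = jnp.arange(num_tokens, dtype=jnp.int32).repeat(topk)
x = hidden_states[token_indices[sort_idx]]           # (num_tokens*topk, hidden)
group_sizes = jnp.bincount(topk_indices_flat, length=num_experts)
print(group_sizes)                                   # [2 2 4]

# ... the grouped matmuls (w1, activation, w2) operate on x here ...

out = x[revert_idx].reshape(num_tokens, topk, hidden)
out = (out * topk_weights[..., None]).sum(axis=-2)   # weighted per-token combine
print(out.shape)                                     # (4, 8)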
28 changes: 21 additions & 7 deletions tpu_inference/layers/vllm/quantization/common.py
@@ -11,6 +11,7 @@
ReplicatedLinear,
RowParallelLinear)

from tpu_inference.layers.common.sharding import ShardingAxisName
from tpu_inference.layers.vllm.linear_common import \
get_model_matmul_fusion_assignment
from tpu_inference.utils import TPU_SECOND_LAST_MINOR
@@ -34,15 +35,23 @@ def __init__(self, vllm_config: VllmConfig, mesh: Mesh, layer: LinearBase):
self.enable_sequence_parallelism = vllm_config.compilation_config.pass_config.enable_sequence_parallelism
self.input_sharding = None
self.output_sharding = None
self.tp_size = self.mesh.shape['model'] * self.mesh.shape.get(
'attn_dp', 1)

if isinstance(layer, RowParallelLinear):
self.weight_sharding = P(None, "model")
self.weight_sharding = P(None, ShardingAxisName.MLP_TENSOR)
if self.enable_sequence_parallelism:
self.output_sharding = P("model", None)
self.output_sharding = P(ShardingAxisName.MLP_TENSOR, None)
elif isinstance(layer, ColumnParallelLinear):
self.weight_sharding = P("model", None)
if isinstance(layer, QKVParallelLinear):
self.input_sharding = P(ShardingAxisName.ATTN_DATA, None)
self.weight_sharding = P('model', None)
self.output_sharding = P(ShardingAxisName.ATTN_DATA, "model")
else:
self.weight_sharding = P(ShardingAxisName.MLP_TENSOR, None)

if self.enable_sequence_parallelism:
self.input_sharding = P("model", None)
self.input_sharding = P(ShardingAxisName.MLP_TENSOR, None)

if isinstance(layer, MergedColumnParallelLinear) or isinstance(
layer, QKVParallelLinear):
@@ -61,13 +70,18 @@ def __init__(self, vllm_config: VllmConfig, mesh: Mesh, layer: LinearBase):
" bad performance.", type(layer))

self.bias_sharding = P(self.weight_sharding[0])
self.n_shards = self.mesh.shape.get(self.weight_sharding[0], 1)
if isinstance(self.weight_sharding[0], tuple):
self.n_shards = 1
for axis in self.weight_sharding[0]:
self.n_shards *= self.mesh.shape.get(axis, 1)
else:
self.n_shards = self.mesh.shape.get(self.weight_sharding[0], 1)

def get_input_sharding(self, x: torchax.tensor.Tensor):
if self.enable_sequence_parallelism:
token_num = x.shape[0]
# NOTE(chengjiyao): make sure the sharded token_num is larger than TPU_SECOND_LAST_MINOR
if token_num // self.mesh.shape["model"] >= TPU_SECOND_LAST_MINOR:
if token_num // self.tp_size >= TPU_SECOND_LAST_MINOR:
return self.input_sharding
else:
return None
@@ -77,7 +91,7 @@ def get_output_sharding(self, x: torchax.tensor.Tensor):
if self.enable_sequence_parallelism:
token_num = x.shape[0]
# NOTE(chengjiyao): make sure the sharded token_num is larger than TPU_SECOND_LAST_MINOR
if token_num // self.mesh.shape["model"] >= TPU_SECOND_LAST_MINOR:
if token_num // self.tp_size >= TPU_SECOND_LAST_MINOR:
return self.output_sharding
else:
return None
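The n_shards bookkeeping above now handles a weight whose first sharding entry is a tuple of mesh axes: the shard count is the product of the corresponding mesh sizes. A tiny illustration with assumed axis names and sizes:

from jax.sharding import PartitionSpec as P

mesh_shape = {"data": 2, "model": 4, "attn_dp": 1}   # assumed mesh sizes
weight_sharding = P(("data", "model"), None)

first = weight_sharding[0]
if isinstance(first, tuple):
    # Weight dim 0 is sharded over several mesh axes at once.
    n_shards = 1
    for axis in first:
        n_shards *= mesh_shape.get(axis, 1)
else:
    n_shards = mesh_shape.get(first, 1)
print(n_shards)  # 2 * 4 = 8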