39 changes: 39 additions & 0 deletions tpu_inference/layers/jax/pp_utils.py
@@ -0,0 +1,39 @@
from typing import List, Protocol

from flax import nnx
from vllm.distributed import get_pp_group
from vllm.distributed.utils import get_pp_indices


class PPMissingLayer(nnx.Module):
"""
A placeholder layer for missing layers in a pipeline parallel model.
"""

def __init__(self, *args, **kwargs):
pass

def __call__(self, *args, **kwargs):
"""Return the first arg from args or the first value from kwargs."""
return args[0] if args else next(iter(kwargs.values()))


class LayerFn(Protocol):

def __call__(self) -> nnx.Module:
...


def make_layers(
num_hidden_layers: int,
layer_fn: LayerFn,
) -> tuple[int, int, List[nnx.Module]]:
start_layer, end_layer = get_pp_indices(num_hidden_layers,
get_pp_group().rank_in_group,
get_pp_group().world_size)

layers = [PPMissingLayer() for _ in range(start_layer)] \
+ [layer_fn() for _ in range(start_layer, end_layer)] \
+ [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)]

return start_layer, end_layer, layers
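
Note: the sketch below illustrates the per-rank layer layout that make_layers produces. even_pp_indices is a hypothetical stand-in for vLLM's get_pp_indices and assumes layers divide evenly across ranks; it is not part of this change.

def even_pp_indices(num_layers: int, rank: int, world_size: int) -> tuple[int, int]:
    # Assumes num_layers % world_size == 0; the real helper also handles remainders.
    per_rank = num_layers // world_size
    return rank * per_rank, (rank + 1) * per_rank

num_layers, world_size = 8, 4
for rank in range(world_size):
    start, end = even_pp_indices(num_layers, rank, world_size)
    # Each rank materializes real layers only for [start, end); the rest are
    # PPMissingLayer placeholders, so indices stay aligned with the full model.
    layout = ["miss"] * start + ["real"] * (end - start) + ["miss"] * (num_layers - end)
    print(rank, (start, end), layout)
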
4 changes: 3 additions & 1 deletion tpu_inference/models/common/model_loader.py
@@ -217,7 +217,9 @@ def get_flax_model(
hidden_states_sharding, # aux hidden states
),
donate_argnums=2, # 0 is graphdef, 1 is state, 2 is kv_cache
static_argnums=6, #6 is layer_name_to_kvcache_index
static_argnums=(
6, 9, 10
), #6 is layer_name_to_kvcache_index, 9 is is_first_rank, 10 is is_last_rank
)
def run_model(graphdef, state, *args):
model = nnx.merge(graphdef, state)
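Note: is_first_rank and is_last_rank are added to static_argnums because they drive Python-level branching inside the model and change the returned structure, so each value needs its own trace. The minimal sketch below is independent of this repository; the function and argument names are illustrative.

from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnums=(1, 2))
def forward(x, is_first_rank: bool, is_last_rank: bool):
    if is_first_rank:                # plain Python branch, resolved at trace time
        x = x * 2.0                  # stand-in for the embedding lookup
    if is_last_rank:
        return x.sum()               # stand-in for the final norm / logits path
    return {"hidden_states": x}      # stand-in for intermediate tensors


print(forward(jnp.ones(4), True, True))    # scalar output on the last rank
print(forward(jnp.ones(4), False, False))  # dict of hidden states on other ranks
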
111 changes: 78 additions & 33 deletions tpu_inference/models/jax/llama3.py
@@ -1,3 +1,4 @@
from itertools import islice
from typing import List, Optional, Tuple

import jax
@@ -6,13 +7,17 @@
from jax.sharding import Mesh
from transformers import LlamaConfig, modeling_flax_utils
from vllm.config import VllmConfig
from vllm.distributed import get_pp_group

from tpu_inference import utils
from tpu_inference.layers.common.attention_interface import attention
from tpu_inference.layers.common.attention_metadata import AttentionMetadata
from tpu_inference.layers.common.sharding import ShardingAxisName
from tpu_inference.layers.jax.pp_utils import PPMissingLayer, make_layers
from tpu_inference.layers.jax.rope_interface import apply_rope
from tpu_inference.logger import init_logger
from tpu_inference.models.jax.jax_intermediate_tensor import \
JaxIntermediateTensors
from tpu_inference.models.jax.utils.weight_utils import (get_default_maps,
load_hf_weights)

@@ -235,38 +240,52 @@ def __init__(self, vllm_config: VllmConfig, rng: nnx.Rngs,
rms_norm_eps = hf_config.rms_norm_eps
hidden_size = hf_config.hidden_size

self.embed = nnx.Embed(
num_embeddings=vocab_size,
features=hidden_size,
param_dtype=dtype,
embedding_init=nnx.with_partitioning(
init_fn, (ShardingAxisName.VOCAB, None)),
rngs=rng,
)
self.layers = [
LlamaDecoderLayer(
self.is_first_rank = get_pp_group().is_first_rank
self.is_last_rank = get_pp_group().is_last_rank

if self.is_first_rank or (hf_config.tie_word_embeddings
and self.is_last_rank):
self.embed = nnx.Embed(
num_embeddings=vocab_size,
features=hidden_size,
param_dtype=dtype,
embedding_init=nnx.with_partitioning(
init_fn, (ShardingAxisName.VOCAB, None)),
rngs=rng,
)
else:
self.embed = PPMissingLayer()

self.start_layer, self.end_layer, self.layers = make_layers(
hf_config.num_hidden_layers,
lambda: LlamaDecoderLayer(
config=hf_config,
dtype=dtype,
rng=rng,
mesh=mesh,
# TODO (jacobplatin): we should refactor this to pass a dtype (or config) directly
kv_cache_dtype=vllm_config.cache_config.cache_dtype)
for _ in range(hf_config.num_hidden_layers)
]
self.norm = nnx.RMSNorm(
hidden_size,
epsilon=rms_norm_eps,
param_dtype=dtype,
scale_init=nnx.with_partitioning(init_fn, (None, )),
rngs=rng,
)
if model_config.hf_config.tie_word_embeddings:
self.lm_head = self.embed.embedding
else:
self.lm_head = nnx.Param(
init_fn(rng.params(), (hidden_size, vocab_size), dtype),
sharding=(None, ShardingAxisName.VOCAB),
kv_cache_dtype=vllm_config.cache_config.cache_dtype))
if self.is_last_rank:
self.norm = nnx.RMSNorm(
hidden_size,
epsilon=rms_norm_eps,
param_dtype=dtype,
scale_init=nnx.with_partitioning(init_fn, (None, )),
rngs=rng,
)
else:
self.norm = PPMissingLayer()

if self.is_last_rank:
if model_config.hf_config.tie_word_embeddings:
self.lm_head = self.embed.embedding
else:
self.lm_head = nnx.Param(
init_fn(rng.params(), (hidden_size, vocab_size), dtype),
sharding=(None, ShardingAxisName.VOCAB),
)
else:
self.lm_head = PPMissingLayer()

self.aux_hidden_state_layers = []
if vllm_config.speculative_config and vllm_config.speculative_config.method == "eagle3":
@@ -282,10 +301,18 @@ def __call__(
kv_caches: List[jax.Array],
input_ids: jax.Array,
attention_metadata: AttentionMetadata,
) -> Tuple[List[jax.Array], jax.Array, List[jax.Array]]:
x = self.embed(input_ids)
intermediate_tensors: JaxIntermediateTensors | None,
) -> Tuple[List[jax.Array], jax.Array, List[jax.Array]] | Tuple[
List[jax.Array], JaxIntermediateTensors]:
if self.is_first_rank:
x = self.embed(input_ids)
else:
assert intermediate_tensors is not None
x = intermediate_tensors["hidden_states"]

aux_hidden_states = []
for i, layer in enumerate(self.layers):
for i, layer in enumerate(
islice(self.layers, self.start_layer, self.end_layer)):
if i in self.aux_hidden_state_layers:
aux_hidden_states.append(x)
kv_cache = kv_caches[i]
@@ -295,6 +322,10 @@ def __call__(
attention_metadata,
)
kv_caches[i] = kv_cache
if not self.is_last_rank:
            # Note: also return aux_hidden_states so the output structure is
            # consistent across pipeline ranks.
            return kv_caches, JaxIntermediateTensors(
                {"hidden_states": x}), aux_hidden_states
x = self.norm(x)
return kv_caches, x, aux_hidden_states
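
Note: the control flow above means only the first rank consumes input_ids and only the last rank produces final hidden states; every other rank hands a hidden-states dict to its successor. The framework-free sketch below mimics that handoff; all names are illustrative and nothing here is repository code.

def rank_forward(rank, world_size, token_ids=None, incoming=None):
    if rank == 0:
        x = [float(t) for t in token_ids]   # stand-in for self.embed(input_ids)
    else:
        x = incoming["hidden_states"]       # stand-in for intermediate_tensors["hidden_states"]
    x = [v + 1.0 for v in x]                # stand-in for this rank's decoder layers
    if rank < world_size - 1:
        return {"hidden_states": x}         # forwarded to the next pipeline rank
    return [v * 0.5 for v in x]             # stand-in for the final norm on the last rank

out = None
for rank in range(4):
    out = rank_forward(rank, 4, token_ids=[1, 2, 3], incoming=out)
print(out)  # [2.5, 3.0, 3.5]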

@@ -313,19 +344,32 @@ def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array,
mesh=mesh,
)

self.pp_missing_layers = []
for path, module in nnx.iter_graph(self.model):
if isinstance(module, PPMissingLayer):
                # The path is a tuple such as ('layers', '0').
self.pp_missing_layers.append('.'.join([str(s) for s in path]))

def __call__(
self,
kv_caches: List[jax.Array],
input_ids: jax.Array,
attention_metadata: AttentionMetadata,
_input_embeds,
_layer_name_to_kv_cache,
_lora_metadata,
intermediate_tensors: JaxIntermediateTensors,
_is_first_rank: bool,
_is_last_rank: bool,
*args,
) -> Tuple[List[jax.Array], jax.Array, List[jax.Array]]:
kv_caches, x, aux_hidden_states = self.model(
) -> Tuple[List[jax.Array], jax.Array, List[jax.Array]] | Tuple[
List[jax.Array], JaxIntermediateTensors]:
return self.model(
kv_caches,
input_ids,
attention_metadata,
intermediate_tensors,
)
return kv_caches, x, aux_hidden_states

def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
if self.vllm_config.model_config.hf_config.tie_word_embeddings:
@@ -372,4 +416,5 @@ def load_weights(self, rng_key: jax.Array):
load_hf_weights(vllm_config=self.vllm_config,
model=self,
metadata_map=metadata_map,
mesh=self.mesh)
mesh=self.mesh,
pp_missing_layers=self.pp_missing_layers)
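
Note: pp_missing_layers is built by walking the merged module graph and recording the dotted path of every PPMissingLayer, so the weight loader can skip those parameters. The sketch below shows the shape of that traversal on a toy model; Placeholder, Block, and TinyModel are illustrative only, and it assumes the flax.nnx version used here (nnx.iter_graph yielding (path, node) pairs).

from flax import nnx


class Placeholder(nnx.Module):
    def __init__(self, *args, **kwargs):
        pass


class Block(nnx.Module):
    def __init__(self, rngs: nnx.Rngs):
        self.proj = nnx.Linear(4, 4, rngs=rngs)


class TinyModel(nnx.Module):
    def __init__(self, rngs: nnx.Rngs):
        # Pretend this rank owns only the middle layer.
        self.layers = [Placeholder(), Block(rngs), Placeholder()]


model = TinyModel(nnx.Rngs(0))
missing = [
    ".".join(str(p) for p in path)
    for path, node in nnx.iter_graph(model)
    if isinstance(node, Placeholder)
]
print(missing)  # expected: ['layers.0', 'layers.2']
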
21 changes: 18 additions & 3 deletions tpu_inference/models/jax/utils/weight_utils.py
@@ -273,7 +273,8 @@ def _load_hf_weights_on_thread(vllm_config,
weights_file: str,
filter_regex: str | None = None,
keep_original_dtype_keys_regex: list[str]
| None = None):
| None = None,
pp_missing_layers: list[str] | None = None):
name_map = metadata_map.name_map
reshape_keys = metadata_map.reshape_map
bias_reshape_keys = metadata_map.bias_reshape_map
@@ -338,6 +339,17 @@
)
continue
model_key = name_map.get(hf_key, hf_key)
        # Skip weights that belong to layers owned by other pipeline ranks.
        def is_pp_missing_layer(hf_key):
            has_digit = any(char.isdigit() for char in hf_key)
            # Append a "." so that e.g. "layers.1" does not also match "layers.10".
            suffix = "." if has_digit else ""
            return any(f'{pp_missing_layer}{suffix}' in hf_key
                       for pp_missing_layer in pp_missing_layers)

if pp_missing_layers and is_pp_missing_layer(hf_key):
continue

model_weight, model_sharding = get_param_and_sharding(
params, shardings, model_key)

@@ -408,14 +420,16 @@ def load_hf_weights(vllm_config,
mesh: Mesh,
filter_regex: str | None = None,
is_draft_model: bool = False,
keep_original_dtype_keys_regex: list[str] | None = None):
keep_original_dtype_keys_regex: list[str] | None = None,
pp_missing_layers: list[str] | None = None):
"""Load weights from all model weights files to the model, run in multi threads."""
if is_draft_model:
model_path = vllm_config.speculative_config.draft_model_config.model
else:
model_path = vllm_config.model_config.model
weights_files = get_model_weights_files(
model_path, vllm_config.load_config.download_dir)
    # With pipeline parallelism, this state holds only the layers owned by this rank.
params = nnx.state(model)
max_workers = min(64, len(weights_files))
# NOTE(xiang): Disable multi-threading mode if running on multi-host.
@@ -433,7 +447,8 @@ def load_hf_weights(vllm_config,
mesh,
weights_file,
filter_regex=filter_regex,
keep_original_dtype_keys_regex=keep_original_dtype_keys_regex)
keep_original_dtype_keys_regex=keep_original_dtype_keys_regex,
pp_missing_layers=pp_missing_layers)
for weights_file in weights_files
]
for future in futures:
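Note: the matching rule in is_pp_missing_layer appends a "." to digit-terminated layer paths so that "layers.1" cannot also match "layers.10". The standalone re-implementation below is for illustration only.

def is_pp_missing_layer(hf_key: str, pp_missing_layers: list[str]) -> bool:
    has_digit = any(ch.isdigit() for ch in hf_key)
    # Only digit-bearing keys need the trailing "." disambiguator.
    suffix = "." if has_digit else ""
    return any(f"{missing}{suffix}" in hf_key for missing in pp_missing_layers)


missing = ["model.layers.1"]
print(is_pp_missing_layer("model.layers.1.self_attn.q_proj.weight", missing))   # True -> skipped
print(is_pp_missing_layer("model.layers.10.self_attn.q_proj.weight", missing))  # False -> loaded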