diff --git a/tpu_inference/layers/jax/pp_utils.py b/tpu_inference/layers/jax/pp_utils.py
new file mode 100644
index 000000000..1f8299522
--- /dev/null
+++ b/tpu_inference/layers/jax/pp_utils.py
@@ -0,0 +1,39 @@
+from typing import List, Protocol
+
+from flax import nnx
+from vllm.distributed import get_pp_group
+from vllm.distributed.utils import get_pp_indices
+
+
+class PPMissingLayer(nnx.Module):
+    """
+    A placeholder layer for missing layers in a pipeline parallel model.
+    """
+
+    def __init__(self, *args, **kwargs):
+        pass
+
+    def __call__(self, *args, **kwargs):
+        """Return the first arg from args or the first value from kwargs."""
+        return args[0] if args else next(iter(kwargs.values()))
+
+
+class LayerFn(Protocol):
+
+    def __call__(self) -> nnx.Module:
+        ...
+
+
+def make_layers(
+    num_hidden_layers: int,
+    layer_fn: LayerFn,
+) -> tuple[int, int, List[nnx.Module]]:
+    start_layer, end_layer = get_pp_indices(num_hidden_layers,
+                                            get_pp_group().rank_in_group,
+                                            get_pp_group().world_size)
+
+    layers = [PPMissingLayer() for _ in range(start_layer)] \
+        + [layer_fn() for _ in range(start_layer, end_layer)] \
+        + [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)]
+
+    return start_layer, end_layer, layers
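Aside from the diff itself, here is a minimal sketch of the behavior the two helpers above are meant to provide. The placeholder class and the hard-coded split are local stand-ins for illustration; the real `make_layers` delegates the split to vLLM's `get_pp_indices`, which decides the contiguous `[start, end)` range owned by each pipeline rank.

```python
# Sketch only: _Passthrough stands in for PPMissingLayer, and the [4, 8) range
# is an assumed example split for one rank of a 16-layer model.


class _Passthrough:

    def __call__(self, *args, **kwargs):
        # Same contract as PPMissingLayer.__call__: echo the first positional
        # argument, or the first keyword value if there are no positionals.
        return args[0] if args else next(iter(kwargs.values()))


hidden = object()
assert _Passthrough()(hidden) is hidden
assert _Passthrough()(hidden_states=hidden) is hidden

# make_layers keeps global layer indices: a rank owning layers [4, 8) still
# gets a 16-entry list, with placeholders everywhere outside its own range.
start, end, num_layers = 4, 8, 16
layers = ([_Passthrough()] * start
          + ["decoder-layer"] * (end - start)
          + [_Passthrough()] * (num_layers - end))
assert len(layers) == num_layers and layers[4] == "decoder-layer"
```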
diff --git a/tpu_inference/models/common/model_loader.py b/tpu_inference/models/common/model_loader.py
index 32965676f..7f7f71afa 100644
--- a/tpu_inference/models/common/model_loader.py
+++ b/tpu_inference/models/common/model_loader.py
@@ -217,7 +217,9 @@ def get_flax_model(
             hidden_states_sharding,  # aux hidden states
         ),
         donate_argnums=2,  # 0 is graphdef, 1 is state, 2 is kv_cache
-        static_argnums=6,  #6 is layer_name_to_kvcache_index
+        static_argnums=(
+            6, 9, 10
+        ),  #6 is layer_name_to_kvcache_index, 9 is is_first_rank, 10 is is_last_rank
     )
     def run_model(graphdef, state, *args):
         model = nnx.merge(graphdef, state)
diff --git a/tpu_inference/models/jax/llama3.py b/tpu_inference/models/jax/llama3.py
index 4a25e4c9a..6d6d3a3c3 100644
--- a/tpu_inference/models/jax/llama3.py
+++ b/tpu_inference/models/jax/llama3.py
@@ -1,3 +1,4 @@
+from itertools import islice
 from typing import List, Optional, Tuple
 
 import jax
@@ -6,13 +7,17 @@
 from jax.sharding import Mesh
 from transformers import LlamaConfig, modeling_flax_utils
 from vllm.config import VllmConfig
+from vllm.distributed import get_pp_group
 
 from tpu_inference import utils
 from tpu_inference.layers.common.attention_interface import attention
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
 from tpu_inference.layers.common.sharding import ShardingAxisName
+from tpu_inference.layers.jax.pp_utils import PPMissingLayer, make_layers
 from tpu_inference.layers.jax.rope_interface import apply_rope
 from tpu_inference.logger import init_logger
+from tpu_inference.models.jax.jax_intermediate_tensor import \
+    JaxIntermediateTensors
 from tpu_inference.models.jax.utils.weight_utils import (get_default_maps,
                                                          load_hf_weights)
@@ -235,38 +240,52 @@ def __init__(self, vllm_config: VllmConfig, rng: nnx.Rngs,
         rms_norm_eps = hf_config.rms_norm_eps
         hidden_size = hf_config.hidden_size
 
-        self.embed = nnx.Embed(
-            num_embeddings=vocab_size,
-            features=hidden_size,
-            param_dtype=dtype,
-            embedding_init=nnx.with_partitioning(
-                init_fn, (ShardingAxisName.VOCAB, None)),
-            rngs=rng,
-        )
-        self.layers = [
-            LlamaDecoderLayer(
+        self.is_first_rank = get_pp_group().is_first_rank
+        self.is_last_rank = get_pp_group().is_last_rank
+
+        if self.is_first_rank or (hf_config.tie_word_embeddings
+                                  and self.is_last_rank):
+            self.embed = nnx.Embed(
+                num_embeddings=vocab_size,
+                features=hidden_size,
+                param_dtype=dtype,
+                embedding_init=nnx.with_partitioning(
+                    init_fn, (ShardingAxisName.VOCAB, None)),
+                rngs=rng,
+            )
+        else:
+            self.embed = PPMissingLayer()
+
+        self.start_layer, self.end_layer, self.layers = make_layers(
+            hf_config.num_hidden_layers,
+            lambda: LlamaDecoderLayer(
                 config=hf_config,
                 dtype=dtype,
                 rng=rng,
                 mesh=mesh,
                 # TODO (jacobplatin): we should refactor this to pass a dtype (or config) directly
-                kv_cache_dtype=vllm_config.cache_config.cache_dtype)
-            for _ in range(hf_config.num_hidden_layers)
-        ]
-        self.norm = nnx.RMSNorm(
-            hidden_size,
-            epsilon=rms_norm_eps,
-            param_dtype=dtype,
-            scale_init=nnx.with_partitioning(init_fn, (None, )),
-            rngs=rng,
-        )
-        if model_config.hf_config.tie_word_embeddings:
-            self.lm_head = self.embed.embedding
-        else:
-            self.lm_head = nnx.Param(
-                init_fn(rng.params(), (hidden_size, vocab_size), dtype),
-                sharding=(None, ShardingAxisName.VOCAB),
+                kv_cache_dtype=vllm_config.cache_config.cache_dtype))
+        if self.is_last_rank:
+            self.norm = nnx.RMSNorm(
+                hidden_size,
+                epsilon=rms_norm_eps,
+                param_dtype=dtype,
+                scale_init=nnx.with_partitioning(init_fn, (None, )),
+                rngs=rng,
             )
+        else:
+            self.norm = PPMissingLayer()
+
+        if self.is_last_rank:
+            if model_config.hf_config.tie_word_embeddings:
+                self.lm_head = self.embed.embedding
+            else:
+                self.lm_head = nnx.Param(
+                    init_fn(rng.params(), (hidden_size, vocab_size), dtype),
+                    sharding=(None, ShardingAxisName.VOCAB),
+                )
+        else:
+            self.lm_head = PPMissingLayer()
 
         self.aux_hidden_state_layers = []
         if vllm_config.speculative_config and vllm_config.speculative_config.method == "eagle3":
@@ -282,10 +301,18 @@ def __call__(
         kv_caches: List[jax.Array],
         input_ids: jax.Array,
         attention_metadata: AttentionMetadata,
-    ) -> Tuple[List[jax.Array], jax.Array, List[jax.Array]]:
-        x = self.embed(input_ids)
+        intermediate_tensors: JaxIntermediateTensors | None,
+    ) -> Tuple[List[jax.Array], jax.Array, List[jax.Array]] | Tuple[
+            List[jax.Array], JaxIntermediateTensors]:
+        if self.is_first_rank:
+            x = self.embed(input_ids)
+        else:
+            assert intermediate_tensors is not None
+            x = intermediate_tensors["hidden_states"]
+
         aux_hidden_states = []
-        for i, layer in enumerate(self.layers):
+        for i, layer in enumerate(
+                islice(self.layers, self.start_layer, self.end_layer)):
             if i in self.aux_hidden_state_layers:
                 aux_hidden_states.append(x)
             kv_cache = kv_caches[i]
@@ -295,6 +322,10 @@ def __call__(
                 attention_metadata,
             )
             kv_caches[i] = kv_cache
+        if not self.is_last_rank:
+            # Note: add aux_hidden_states to make the output spec consistent.
+            return kv_caches, JaxIntermediateTensors({"hidden_states":
+                                                      x}), aux_hidden_states
         x = self.norm(x)
         return kv_caches, x, aux_hidden_states
 
@@ -313,19 +344,32 @@ def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array,
             mesh=mesh,
         )
 
+        self.pp_missing_layers = []
+        for path, module in nnx.iter_graph(self.model):
+            if isinstance(module, PPMissingLayer):
+                # the path should be something like ('layers', '0')
+                self.pp_missing_layers.append('.'.join([str(s) for s in path]))
+
     def __call__(
         self,
         kv_caches: List[jax.Array],
         input_ids: jax.Array,
         attention_metadata: AttentionMetadata,
+        _input_embeds,
+        _layer_name_to_kv_cache,
+        _lora_metadata,
+        intermediate_tensors: JaxIntermediateTensors,
+        _is_first_rank: bool,
+        _is_last_rank: bool,
         *args,
-    ) -> Tuple[List[jax.Array], jax.Array, List[jax.Array]]:
-        kv_caches, x, aux_hidden_states = self.model(
+    ) -> Tuple[List[jax.Array], jax.Array, List[jax.Array]] | Tuple[
+            List[jax.Array], JaxIntermediateTensors]:
+        return self.model(
             kv_caches,
             input_ids,
             attention_metadata,
+            intermediate_tensors,
         )
-        return kv_caches, x, aux_hidden_states
 
     def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
         if self.vllm_config.model_config.hf_config.tie_word_embeddings:
@@ -372,4 +416,5 @@ def load_weights(self, rng_key: jax.Array):
         load_hf_weights(vllm_config=self.vllm_config,
                         model=self,
                         metadata_map=metadata_map,
-                        mesh=self.mesh)
+                        mesh=self.mesh,
+                        pp_missing_layers=self.pp_missing_layers)
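A rough sketch of the per-rank contract that `LlamaModel.__call__` now follows: the first rank embeds `input_ids`, intermediate ranks resume from the previous rank's hidden states and hand their own off, and only the last rank applies the final norm. A plain dict stands in for `JaxIntermediateTensors` and strings stand in for arrays; only the control flow mirrors the code above.

```python
# Sketch only: how activations are expected to move between pipeline ranks.


def run_rank(is_first_rank, is_last_rank, input_ids, intermediate_tensors):
    if is_first_rank:
        x = f"embed({input_ids})"                  # first rank embeds tokens
    else:
        assert intermediate_tensors is not None
        x = intermediate_tensors["hidden_states"]  # resume from the handoff
    x = f"layers({x})"                             # this rank's [start, end) slice
    if not is_last_rank:
        return {"hidden_states": x}                # hand off to the next rank
    return f"norm({x})"                            # only the last rank normalizes


world_size = 3
handoff, out = None, None
for rank in range(world_size):
    out = run_rank(rank == 0, rank == world_size - 1, "input_ids", handoff)
    handoff = out if isinstance(out, dict) else None

print(out)  # norm(layers(layers(layers(embed(input_ids)))))
```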
+ return kv_caches, JaxIntermediateTensors({"hidden_states": + x}), aux_hidden_states x = self.norm(x) return kv_caches, x, aux_hidden_states @@ -313,19 +344,32 @@ def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array, mesh=mesh, ) + self.pp_missing_layers = [] + for path, module in nnx.iter_graph(self.model): + if isinstance(module, PPMissingLayer): + # the path should be sth like ('layers', '0') + self.pp_missing_layers.append('.'.join([str(s) for s in path])) + def __call__( self, kv_caches: List[jax.Array], input_ids: jax.Array, attention_metadata: AttentionMetadata, + _input_embeds, + _layer_name_to_kv_cache, + _lora_metadata, + intermediate_tensors: JaxIntermediateTensors, + _is_first_rank: bool, + _is_last_rank: bool, *args, - ) -> Tuple[List[jax.Array], jax.Array, List[jax.Array]]: - kv_caches, x, aux_hidden_states = self.model( + ) -> Tuple[List[jax.Array], jax.Array, List[jax.Array]] | Tuple[ + List[jax.Array], JaxIntermediateTensors]: + return self.model( kv_caches, input_ids, attention_metadata, + intermediate_tensors, ) - return kv_caches, x, aux_hidden_states def compute_logits(self, hidden_states: jax.Array) -> jax.Array: if self.vllm_config.model_config.hf_config.tie_word_embeddings: @@ -372,4 +416,5 @@ def load_weights(self, rng_key: jax.Array): load_hf_weights(vllm_config=self.vllm_config, model=self, metadata_map=metadata_map, - mesh=self.mesh) + mesh=self.mesh, + pp_missing_layers=self.pp_missing_layers) diff --git a/tpu_inference/models/jax/utils/weight_utils.py b/tpu_inference/models/jax/utils/weight_utils.py index 64f026dae..a7c12a619 100644 --- a/tpu_inference/models/jax/utils/weight_utils.py +++ b/tpu_inference/models/jax/utils/weight_utils.py @@ -273,7 +273,8 @@ def _load_hf_weights_on_thread(vllm_config, weights_file: str, filter_regex: str | None = None, keep_original_dtype_keys_regex: list[str] - | None = None): + | None = None, + pp_missing_layers: list[str] | None = None): name_map = metadata_map.name_map reshape_keys = metadata_map.reshape_map bias_reshape_keys = metadata_map.bias_reshape_map @@ -338,6 +339,17 @@ def _load_hf_weights_on_thread(vllm_config, ) continue model_key = name_map.get(hf_key, hf_key) + # add skip pp missing layers. + def is_pp_missing_layer(hf_key): + has_digit = any(char.isdigit() for char in hf_key) + # add the suffix after digits to avoid it matches "layers.10" with "layers.1" + suffix = "." if has_digit else "" + return any(f'{pp_missing_layer}{suffix}' in hf_key + for pp_missing_layer in pp_missing_layers) + + if pp_missing_layers and is_pp_missing_layer(hf_key): + continue + model_weight, model_sharding = get_param_and_sharding( params, shardings, model_key) @@ -408,7 +420,8 @@ def load_hf_weights(vllm_config, mesh: Mesh, filter_regex: str | None = None, is_draft_model: bool = False, - keep_original_dtype_keys_regex: list[str] | None = None): + keep_original_dtype_keys_regex: list[str] | None = None, + pp_missing_layers: list[str] | None = None): """Load weights from all model weights files to the model, run in multi threads.""" if is_draft_model: model_path = vllm_config.speculative_config.draft_model_config.model @@ -416,6 +429,7 @@ def load_hf_weights(vllm_config, model_path = vllm_config.model_config.model weights_files = get_model_weights_files( model_path, vllm_config.load_config.download_dir) + # For PP, params are partial. params = nnx.state(model) max_workers = min(64, len(weights_files)) # NOTE(xiang): Disable multi-threading mode if running on multi-host. 
@@ -408,7 +420,8 @@ def load_hf_weights(vllm_config,
                     mesh: Mesh,
                     filter_regex: str | None = None,
                     is_draft_model: bool = False,
-                    keep_original_dtype_keys_regex: list[str] | None = None):
+                    keep_original_dtype_keys_regex: list[str] | None = None,
+                    pp_missing_layers: list[str] | None = None):
     """Load weights from all model weights files to the model, run in multi threads."""
     if is_draft_model:
         model_path = vllm_config.speculative_config.draft_model_config.model
@@ -416,6 +429,7 @@ def load_hf_weights(vllm_config,
         model_path = vllm_config.model_config.model
     weights_files = get_model_weights_files(
         model_path, vllm_config.load_config.download_dir)
+    # For PP, params are partial.
    params = nnx.state(model)
     max_workers = min(64, len(weights_files))
     # NOTE(xiang): Disable multi-threading mode if running on multi-host.
@@ -433,7 +447,8 @@ def load_hf_weights(vllm_config,
                 mesh,
                 weights_file,
                 filter_regex=filter_regex,
-                keep_original_dtype_keys_regex=keep_original_dtype_keys_regex)
+                keep_original_dtype_keys_regex=keep_original_dtype_keys_regex,
+                pp_missing_layers=pp_missing_layers)
             for weights_file in weights_files
         ]
         for future in futures:
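One closing aside on the model_loader.py change earlier in this diff: `is_first_rank` and `is_last_rank` drive plain Python branches inside the model, so the jitted entry point has to treat them as static arguments and compile one specialization per value. A minimal, self-contained illustration of that behavior follows; the toy `forward` is not the real `run_model`, only the static-argument mechanics are the point.

```python
# Sketch only: a static bool selects a Python-level branch at trace time,
# so each value gets its own compiled specialization.
import jax
import jax.numpy as jnp


def forward(x, is_last_rank: bool):
    if is_last_rank:       # resolved while tracing, not inside the graph
        return x * 2.0     # e.g. the "apply final norm / lm_head" path
    return x + 1.0         # e.g. the "hand activations to the next rank" path


forward_jit = jax.jit(forward, static_argnums=(1,))
print(forward_jit(jnp.array(3.0), True))   # 6.0
print(forward_jit(jnp.array(3.0), False))  # 4.0
```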