diff --git a/tests/e2e/test_data_parallel.py b/tests/e2e/test_data_parallel.py
index 79c1a56e8..9d794df29 100644
--- a/tests/e2e/test_data_parallel.py
+++ b/tests/e2e/test_data_parallel.py
@@ -9,12 +9,6 @@
 from vllm import LLM, EngineArgs, SamplingParams
 
 
-@pytest.fixture
-def model_name():
-    """Small model for faster testing."""
-    return "Qwen/Qwen2.5-1.5B-Instruct"
-
-
 @pytest.fixture(autouse=True)
 def setup_new_model_design():
     """Automatically set NEW_MODEL_DESIGN=True for all tests."""
@@ -56,21 +50,23 @@ def _run_inference_with_config(model_name: str,
                                data_parallel_size: int = 1,
                                additional_config: dict = {},
                                kv_cache_dtype: str = "auto",
-                               enable_prefix_caching: bool = False) -> list:
+                               enable_prefix_caching: bool = False,
+                               async_scheduling: bool = False) -> list:
     """Helper function to run inference with specified configuration."""
 
     # Create LLM args using parser-based approach similar to offline_inference.py
     engine_args = EngineArgs(
         model=model_name,
-        max_model_len=128,
+        max_model_len=32,
         tensor_parallel_size=tensor_parallel_size,
         data_parallel_size=data_parallel_size,
-        gpu_memory_utilization=0.95,
+        gpu_memory_utilization=0.98,
         max_num_batched_tokens=128,
         max_num_seqs=16,
         enable_prefix_caching=enable_prefix_caching,
         additional_config=additional_config,
         kv_cache_dtype=kv_cache_dtype,
+        async_scheduling=async_scheduling,
     )
 
     engine_args_dict = asdict(engine_args)
@@ -86,7 +82,6 @@
 
 def test_model_data_parallelism(
-    model_name: str,
     test_prompts: list,
     sampling_params: SamplingParams,
 ):
     """Test model-level data parallelism with TP=1, DP=2.
@@ -98,9 +93,12 @@
     Equivalent to:
     python examples/offline_inference.py --tensor_parallel_size=4 --data_parallel_size=2
     """
+    # Use Llama 1B for this test
+    test_model = "meta-llama/Llama-3.2-1B-Instruct"
+
     # Test with data parallelism enabled
     outputs = _run_inference_with_config(
-        model_name=model_name,
+        model_name=test_model,
         test_prompts=test_prompts,
         sampling_params=sampling_params,
         tensor_parallel_size=1,
@@ -119,7 +117,6 @@
 
 def test_attention_data_parallelism(
-    model_name: str,
     test_prompts: list,
     sampling_params: SamplingParams,
 ):
     """Test attention data parallelism with TP=8, DP=2.
@@ -132,6 +129,9 @@
     python examples/offline_inference.py --tensor_parallel_size=8 --kv-cache-dtype=fp8 \
         --additional_config='{"sharding":{"sharding_strategy": {"enable_dp_attention":1}}}'
     """
+    # Use Qwen3 0.6B for this test
+    test_model = "Qwen/Qwen3-0.6B"
+
     additional_config = {
         "sharding": {
             "sharding_strategy": {
@@ -142,7 +142,7 @@
 
     # Test with attention data parallelism enabled
     outputs = _run_inference_with_config(
-        model_name=model_name,
+        model_name=test_model,
         test_prompts=test_prompts,
         sampling_params=sampling_params,
         tensor_parallel_size=8,
@@ -165,7 +165,6 @@
 
 
 def test_data_parallelism_correctness(
-    model_name: str,
     test_prompts: list,
     sampling_params: SamplingParams,
 ):
@@ -176,7 +175,7 @@
     """
     os.environ['SKIP_JAX_PRECOMPILE'] = '1'
     os.environ['VLLM_XLA_CHECK_RECOMPILATION'] = '0'
-
+    model_name = "Qwen/Qwen2.5-1.5B-Instruct"
     # Use a smaller subset of prompts for correctness testing
     small_prompts = test_prompts[:10]
 
@@ -187,6 +186,7 @@
         sampling_params=sampling_params,
         tensor_parallel_size=1,
         data_parallel_size=1,
+        async_scheduling=True,
     )
 
     # Run with model data parallelism and async scheduling
@@ -196,9 +196,7 @@
         sampling_params=sampling_params,
         tensor_parallel_size=1,
         data_parallel_size=2,
-        additional_config={"scheduler_config": {
-            "async_scheduling": True
-        }},
+        async_scheduling=True,
     )
 
     # Compare outputs - they should be identical for greedy sampling
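Not part of the patch, for illustration only: a minimal sketch of the configuration change these tests exercise. Async scheduling is now requested through the EngineArgs field added in this diff rather than through additional_config. The sketch assumes the same LLM(**asdict(engine_args)) construction that _run_inference_with_config uses; the model name and prompt are placeholders.

    from dataclasses import asdict

    from vllm import LLM, EngineArgs, SamplingParams

    engine_args = EngineArgs(
        model="Qwen/Qwen2.5-1.5B-Instruct",  # placeholder; the tests pick per-test models
        tensor_parallel_size=1,
        data_parallel_size=2,
        max_model_len=32,
        max_num_batched_tokens=128,
        max_num_seqs=16,
        # Previously: additional_config={"scheduler_config": {"async_scheduling": True}}
        async_scheduling=True,
    )
    llm = LLM(**asdict(engine_args))
    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(temperature=0.0, max_tokens=8))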
diff --git a/tests/runner/test_tpu_runner_dp.py b/tests/runner/test_tpu_runner_dp.py
index 07e76f91b..2fbd1e33e 100644
--- a/tests/runner/test_tpu_runner_dp.py
+++ b/tests/runner/test_tpu_runner_dp.py
@@ -102,8 +102,8 @@ def test_prepare_inputs_dp_basic_functionality(self,
         result = self.runner._prepare_inputs_dp(scheduler_output)
 
         # Basic assertions
-        assert len(result) == 6
-        input_ids, attention_metadata, sampling_metadata, logits_indices, spec_decode_metadata, logits_indices_selector = result
+        assert len(result) == 7
+        input_ids, attention_metadata, sampling_metadata, logits_indices, spec_decode_metadata, logits_indices_selector, padded_num_reqs = result
 
         # Verify utility functions were called
         mock_runner_utils.get_padded_token_len.assert_called()
@@ -380,8 +380,7 @@ def mock_get_padded_token_len(paddings_list, val):
         # Execute the method
         result = self.runner._prepare_inputs_dp(scheduler_output)
 
-        input_ids, attention_metadata, sampling_metadata, logits_indices, spec_decode_metadata, logits_indices_selector = result
-
+        input_ids, attention_metadata, sampling_metadata, logits_indices, spec_decode_metadata, logits_indices_selector, padded_num_reqs = result
         # 1. Verify input_ids content
         expected_input_ids = np.zeros(16, dtype=np.int32)
         expected_input_ids[:2] = [1006, 1007]
@@ -495,7 +494,7 @@ def mock_get_padded_token_len(paddings_list, val):
         # Execute the method
         result = self.runner._prepare_inputs_dp(scheduler_output)
 
-        input_ids, attention_metadata, sampling_metadata, logits_indices, spec_decode_metadata, logits_indices_selector = result
+        input_ids, attention_metadata, sampling_metadata, logits_indices, spec_decode_metadata, logits_indices_selector, padded_num_reqs = result
 
         # 1. Verify input_ids
         expected_input_ids = np.zeros(16, dtype=np.int32)
@@ -724,7 +723,7 @@ def test_prepare_inputs_routing_to_non_dp(self):
         self.runner.dp_size = 1
 
         self.runner._prepare_inputs_non_dp = MagicMock(
-            return_value=(None, None, None, None, None, None))
+            return_value=(None, None, None, None, None, None, None))
         scheduler_output = MagicMock()
 
         self.runner._prepare_inputs(scheduler_output)
diff --git a/tpu_inference/runner/kv_cache_manager.py b/tpu_inference/runner/kv_cache_manager.py
index bd6932fd8..dcb94a966 100644
--- a/tpu_inference/runner/kv_cache_manager.py
+++ b/tpu_inference/runner/kv_cache_manager.py
@@ -1,5 +1,4 @@
 import functools
-import math
 from typing import TYPE_CHECKING, Dict, List
 
 import jax
@@ -190,7 +189,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
             num_blocks = kv_cache_tensor.size // page_size_bytes
             dp_size = self.runner.vllm_config.sharding_config.total_dp_size
             # num_blocks must be a multiple of dp_size
-            num_blocks = math.ceil(num_blocks / dp_size) * dp_size
+            num_blocks = (num_blocks // dp_size) * dp_size
             # NOTE: we'll multiply the num_kv_heads by 2 in the function
             kv_cache = create_kv_caches(
                 num_blocks=num_blocks,
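Not part of the patch, for illustration only: the kv_cache_manager change above swaps math.ceil for floor division, so num_blocks is rounded down to a multiple of dp_size instead of up; rounding up could request more blocks than the allocated kv_cache_tensor actually provides. A self-contained sketch of the new rounding with made-up numbers:

    def round_down_to_multiple(num_blocks: int, dp_size: int) -> int:
        # Largest multiple of dp_size that still fits in num_blocks.
        return (num_blocks // dp_size) * dp_size

    # 1023 blocks with dp_size=4: flooring keeps us at 1020 blocks,
    # whereas the old ceil-based form would have claimed 1024.
    assert round_down_to_multiple(1023, 4) == 1020
    assert round_down_to_multiple(1024, 4) == 1024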
diff --git a/tpu_inference/runner/tpu_runner.py b/tpu_inference/runner/tpu_runner.py
index 3841e7460..f3c6d7899 100644
--- a/tpu_inference/runner/tpu_runner.py
+++ b/tpu_inference/runner/tpu_runner.py
@@ -15,7 +15,7 @@
 from flax import nnx
 from jax.experimental import mesh_utils
 from jax.sharding import NamedSharding, PartitionSpec
-from torchax.ops.mappings import j2t, j2t_dtype
+from torchax.ops.mappings import j2t_dtype
 from vllm.config import VllmConfig
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           has_kv_transfer_group)
@@ -154,6 +154,7 @@ class ExecuteModelState:
     spec_decode_metadata: Optional[SpecDecodeMetadata]
     kv_connector_output: Optional[KVConnectorOutput]
     logits_indices_selector: Optional[List[int]] = None
+    padded_num_reqs: Optional[int] = None
 
 
 @functools.partial(jax.jit, donate_argnums=(0, 1, 2))
@@ -191,19 +192,28 @@ def _substitute_placeholder_token(
     return input_ids.at[token_in_tpu_cur_input_indices].set(update_values)
 
 
-def _reorder_logits_indices(logprobs_lists: LogprobsLists,
-                            logits_indices_selector: List[int]):
+def _jax_logprobs_to_lists(logprobs_tensors,
+                           logits_indices_selector=None,
+                           cu_num_generated_tokens=None):
+    """Convert JAX LogprobsTensors to LogprobsLists by converting JAX arrays to numpy."""
+    log_token_ids_list = logprobs_tensors.logprob_token_ids.tolist()
+    logprobs_list = logprobs_tensors.logprobs.tolist()
+    selected_token_ranks_list = logprobs_tensors.selected_token_ranks.tolist()
+
+    if logits_indices_selector is not None:
+        log_token_ids_list = [
+            log_token_ids_list[i] for i in logits_indices_selector
+        ]
+        logprobs_list = [logprobs_list[i] for i in logits_indices_selector]
+        selected_token_ranks_list = [
+            selected_token_ranks_list[i] for i in logits_indices_selector
+        ]
+
     return LogprobsLists(
-        logprob_token_ids=[
-            logprobs_lists.logprob_token_ids[i]
-            for i in logits_indices_selector
-        ],
-        logprobs=[logprobs_lists.logprobs[i] for i in logits_indices_selector],
-        sampled_token_ranks=[
-            logprobs_lists.sampled_token_ranks[i]
-            for i in logits_indices_selector
-        ],
-        cu_num_generated_tokens=logprobs_lists.cu_num_generated_tokens,
+        logprob_token_ids=np.asarray(log_token_ids_list),
+        logprobs=np.asarray(logprobs_list),
+        sampled_token_ranks=np.asarray(selected_token_ranks_list),
+        cu_num_generated_tokens=cu_num_generated_tokens,
     )
 
 
@@ -552,16 +562,17 @@ def sample_tokens(
 
         (scheduler_output, attn_metadata, input_ids, hidden_states, logits,
          aux_hidden_states, spec_decode_metadata, kv_connector_output,
-         logits_indices_selector) = (
-             self.execute_model_state.scheduler_output,
-             self.execute_model_state.attn_metadata,
-             self.execute_model_state.input_ids,
-             self.execute_model_state.hidden_states,
-             self.execute_model_state.logits,
-             self.execute_model_state.aux_hidden_states,
-             self.execute_model_state.spec_decode_metadata,
-             self.execute_model_state.kv_connector_output,
-             self.execute_model_state.logits_indices_selector)
+         logits_indices_selector,
+         padded_num_reqs) = (self.execute_model_state.scheduler_output,
+                             self.execute_model_state.attn_metadata,
+                             self.execute_model_state.input_ids,
+                             self.execute_model_state.hidden_states,
+                             self.execute_model_state.logits,
+                             self.execute_model_state.aux_hidden_states,
+                             self.execute_model_state.spec_decode_metadata,
+                             self.execute_model_state.kv_connector_output,
+                             self.execute_model_state.logits_indices_selector,
+                             self.execute_model_state.padded_num_reqs)
         self.execute_model_state = None
 
         if grammar_output is not None:
@@ -575,12 +586,10 @@ def sample_tokens(
                 logits,
                 arange,
             )
-        return self._sample_from_logits(scheduler_output, attn_metadata,
-                                        input_ids, hidden_states, logits,
-                                        aux_hidden_states,
-                                        spec_decode_metadata,
-                                        kv_connector_output,
-                                        logits_indices_selector)
+        return self._sample_from_logits(
+            scheduler_output, attn_metadata, input_ids, hidden_states, logits,
+            aux_hidden_states, spec_decode_metadata, kv_connector_output,
+            logits_indices_selector, padded_num_reqs)
 
     def _modify_prev_results(self):
         # If copy to host has not been done, we just wait.
@@ -694,6 +703,7 @@ def _execute_model(
             logits_indices,
             spec_decode_metadata,
             logits_indices_selector,
+            padded_num_reqs,
         ) = self._prepare_inputs(scheduler_output)
 
         # multi-modal support
@@ -756,7 +766,8 @@ def _execute_model(
             aux_hidden_states=aux_hidden_states,
             spec_decode_metadata=spec_decode_metadata,
             kv_connector_output=kv_connector_output,
-            logits_indices_selector=logits_indices_selector)
+            logits_indices_selector=logits_indices_selector,
+            padded_num_reqs=padded_num_reqs)
         return attn_metadata, None
 
     def _sample_from_logits(
@@ -770,11 +781,19 @@
         spec_decode_metadata: Optional[SpecDecodeMetadata],
         kv_connector_output: Optional[KVConnectorOutput],
         logits_indices_selector: Optional[List[int]] = None,
+        padded_num_reqs: Optional[int] = None,
     ) -> ModelRunnerOutput | AsyncTPUModelRunnerOutput:
-        padded_num_reqs = runner_utils.get_padded_num_reqs_with_upper_limit(
-            self.input_batch.num_reqs, self.max_num_reqs)
+        if padded_num_reqs is None:
+            padded_num_reqs = runner_utils.get_padded_num_reqs_with_upper_limit(
+                self.input_batch.num_reqs, self.max_num_reqs)
+
+        sharding = None
+        if self.dp_size > 1:
+            sharding = NamedSharding(self.mesh,
+                                     PartitionSpec(ShardingAxisName.ATTN_DATA))
+
         tpu_sampling_metadata = TPUSupportedSamplingMetadata.from_input_batch(
-            self.mesh, self.input_batch, padded_num_reqs)
+            self.mesh, self.input_batch, padded_num_reqs, sharding=sharding)
         if spec_decode_metadata is None:
             next_tokens = sample(
                 self.rng_params_for_sampling,
@@ -806,8 +825,6 @@ def _sample_from_logits(
         if tpu_sampling_metadata.logprobs:
             logprobs = self._compute_and_gather_logprobs(
                 logits, next_tokens, self.model_config.max_logprobs)
-            logprobs_lists = jax.tree.map(lambda x: j2t(x.astype(jnp.float32)),
-                                          logprobs).tolists()
         else:
             logprobs = None
 
@@ -860,9 +877,8 @@ def _sample_from_logits(
 
         if logprobs is not None:
             # Map logprobs back to the pre-dp shuffling order
-            if logits_indices_selector is not None:
-                logprobs_lists = _reorder_logits_indices(
-                    logprobs_lists, logits_indices_selector)
+            logprobs_lists = _jax_logprobs_to_lists(
+                logprobs, logits_indices_selector)
         else:
             logprobs_lists = None
 
@@ -934,9 +950,8 @@ def _sample_from_logits(
 
         if logprobs is not None:
             # Map logprobs back to the pre-dp shuffling order
-            if logits_indices_selector is not None:
-                logprobs_lists = _reorder_logits_indices(
-                    logprobs_lists, logits_indices_selector)
+            logprobs_lists = _jax_logprobs_to_lists(logprobs,
+                                                    logits_indices_selector)
         else:
             logprobs_lists = None
 
@@ -1397,6 +1412,7 @@ def _prepare_inputs_dp(self, scheduler_output: "VllmSchedulerOutput"):
             logits_indices,
             spec_decode_metadata,
             logits_indices_selector,
+            padded_num_reqs,
         )
 
     def _prepare_inputs_non_dp(self, scheduler_output: "VllmSchedulerOutput"):
@@ -1563,7 +1579,8 @@ def _prepare_inputs_non_dp(self, scheduler_output: "VllmSchedulerOutput"):
         attention_metadata.seq_lens_cpu = seq_lens_cpu
         logits_indices_selector = None
         return (input_ids, attention_metadata, sampling_metadata,
-                logits_indices, spec_decode_metadata, logits_indices_selector)
+                logits_indices, spec_decode_metadata, logits_indices_selector,
+                padded_num_reqs)
 
     def _get_input_ids_embeds(self, input_ids: jax.Array,
                               mm_embeds: list[jax.Array]):
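Not part of the patch, for illustration only: a self-contained sketch of the reordering step inside the new _jax_logprobs_to_lists helper. When a logits_indices_selector is present, rows produced in DP-shuffled order are mapped back to the original request order and then materialized as numpy arrays; the arrays below are stand-ins, not real model output.

    import numpy as np

    # One row of top-k logprob token ids per request, in DP-shuffled order.
    logprob_token_ids = np.array([[11, 12], [21, 22], [31, 32]])
    logits_indices_selector = [2, 0, 1]  # output position i takes shuffled row selector[i]

    restored = np.asarray(
        [logprob_token_ids.tolist()[i] for i in logits_indices_selector])
    print(restored)  # [[31 32] [11 12] [21 22]]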