34 changes: 16 additions & 18 deletions tests/e2e/test_data_parallel.py
@@ -9,12 +9,6 @@
from vllm import LLM, EngineArgs, SamplingParams


@pytest.fixture
def model_name():
"""Small model for faster testing."""
return "Qwen/Qwen2.5-1.5B-Instruct"


@pytest.fixture(autouse=True)
def setup_new_model_design():
"""Automatically set NEW_MODEL_DESIGN=True for all tests."""
@@ -56,21 +50,23 @@ def _run_inference_with_config(model_name: str,
data_parallel_size: int = 1,
additional_config: dict = {},
kv_cache_dtype: str = "auto",
enable_prefix_caching: bool = False) -> list:
enable_prefix_caching: bool = False,
async_scheduling: bool = False) -> list:
"""Helper function to run inference with specified configuration."""

# Create LLM args using parser-based approach similar to offline_inference.py
engine_args = EngineArgs(
model=model_name,
max_model_len=128,
max_model_len=32,
tensor_parallel_size=tensor_parallel_size,
data_parallel_size=data_parallel_size,
gpu_memory_utilization=0.95,
gpu_memory_utilization=0.98,
max_num_batched_tokens=128,
max_num_seqs=16,
enable_prefix_caching=enable_prefix_caching,
additional_config=additional_config,
kv_cache_dtype=kv_cache_dtype,
async_scheduling=async_scheduling,
)

engine_args_dict = asdict(engine_args)
@@ -86,7 +82,6 @@ def _run_inference_with_config(model_name: str,


def test_model_data_parallelism(
model_name: str,
test_prompts: list,
sampling_params: SamplingParams,
):
@@ -98,9 +93,12 @@ def test_model_data_parallelism(
Equivalent to:
python examples/offline_inference.py --tensor_parallel_size=4 --data_parallel_size=2
"""
# Use Llama 1B for this test
test_model = "meta-llama/Llama-3.2-1B-Instruct"

# Test with data parallelism enabled
outputs = _run_inference_with_config(
model_name=model_name,
model_name=test_model,
test_prompts=test_prompts,
sampling_params=sampling_params,
tensor_parallel_size=1,
@@ -119,7 +117,6 @@


def test_attention_data_parallelism(
model_name: str,
test_prompts: list,
sampling_params: SamplingParams,
):
@@ -132,6 +129,9 @@ def test_attention_data_parallelism(
python examples/offline_inference.py --tensor_parallel_size=8 --kv-cache-dtype=fp8 \
--additional_config='{"sharding":{"sharding_strategy": {"enable_dp_attention":1}}}'
"""
# Use Qwen3 0.6B for this test
test_model = "Qwen/Qwen3-0.6B"

additional_config = {
"sharding": {
"sharding_strategy": {
@@ -142,7 +142,7 @@

# Test with attention data parallelism enabled
outputs = _run_inference_with_config(
model_name=model_name,
model_name=test_model,
test_prompts=test_prompts,
sampling_params=sampling_params,
tensor_parallel_size=8,
@@ -165,7 +165,6 @@


def test_data_parallelism_correctness(
model_name: str,
test_prompts: list,
sampling_params: SamplingParams,
):
@@ -176,7 +175,7 @@
"""
os.environ['SKIP_JAX_PRECOMPILE'] = '1'
os.environ['VLLM_XLA_CHECK_RECOMPILATION'] = '0'

model_name = "Qwen/Qwen2.5-1.5B-Instruct"
# Use a smaller subset of prompts for correctness testing
small_prompts = test_prompts[:10]

@@ -187,6 +186,7 @@
sampling_params=sampling_params,
tensor_parallel_size=1,
data_parallel_size=1,
async_scheduling=True,
)

# Run with model data parallelism and async scheduling
@@ -196,9 +196,7 @@
sampling_params=sampling_params,
tensor_parallel_size=1,
data_parallel_size=2,
additional_config={"scheduler_config": {
"async_scheduling": True
}},
async_scheduling=True,
)

# Compare outputs - they should be identical for greedy sampling
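For context, a minimal sketch of how the updated helper could be exercised outside pytest. It assumes _run_inference_with_config is importable from this test module, that test_prompts is a plain list of prompt strings, and that the helper returns vLLM RequestOutput objects; the model name is only an example.

    from vllm import SamplingParams

    # Greedy sampling so DP=1 and DP=2 runs are expected to match exactly.
    sampling_params = SamplingParams(temperature=0.0, max_tokens=16)
    prompts = ["The capital of France is", "Explain data parallelism in one sentence:"]

    # Baseline: single replica, async scheduling on (mirrors the correctness test).
    baseline = _run_inference_with_config(
        model_name="Qwen/Qwen2.5-1.5B-Instruct",
        test_prompts=prompts,
        sampling_params=sampling_params,
        tensor_parallel_size=1,
        data_parallel_size=1,
        async_scheduling=True,
    )

    # Data-parallel run: two replicas, otherwise identical configuration.
    dp_outputs = _run_inference_with_config(
        model_name="Qwen/Qwen2.5-1.5B-Instruct",
        test_prompts=prompts,
        sampling_params=sampling_params,
        tensor_parallel_size=1,
        data_parallel_size=2,
        async_scheduling=True,
    )

    # With greedy sampling the generated texts should be identical.
    for a, b in zip(baseline, dp_outputs):
        assert a.outputs[0].text == b.outputs[0].text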
11 changes: 5 additions & 6 deletions tests/runner/test_tpu_runner_dp.py
@@ -102,8 +102,8 @@ def test_prepare_inputs_dp_basic_functionality(self,
result = self.runner._prepare_inputs_dp(scheduler_output)

# Basic assertions
assert len(result) == 6
input_ids, attention_metadata, sampling_metadata, logits_indices, spec_decode_metadata, logits_indices_selector = result
assert len(result) == 7
input_ids, attention_metadata, sampling_metadata, logits_indices, spec_decode_metadata, logits_indices_selector, padded_num_reqs = result

# Verify utility functions were called
mock_runner_utils.get_padded_token_len.assert_called()
@@ -380,8 +380,7 @@ def mock_get_padded_token_len(paddings_list, val):

# Execute the method
result = self.runner._prepare_inputs_dp(scheduler_output)
input_ids, attention_metadata, sampling_metadata, logits_indices, spec_decode_metadata, logits_indices_selector = result

input_ids, attention_metadata, sampling_metadata, logits_indices, spec_decode_metadata, logits_indices_selector, padded_num_reqs = result
# 1. Verify input_ids content
expected_input_ids = np.zeros(16, dtype=np.int32)
expected_input_ids[:2] = [1006, 1007]
@@ -495,7 +494,7 @@ def mock_get_padded_token_len(paddings_list, val):

# Execute the method
result = self.runner._prepare_inputs_dp(scheduler_output)
input_ids, attention_metadata, sampling_metadata, logits_indices, spec_decode_metadata, logits_indices_selector = result
input_ids, attention_metadata, sampling_metadata, logits_indices, spec_decode_metadata, logits_indices_selector, padded_num_reqs = result

# 1. Verify input_ids
expected_input_ids = np.zeros(16, dtype=np.int32)
@@ -724,7 +723,7 @@ def test_prepare_inputs_routing_to_non_dp(self):

self.runner.dp_size = 1
self.runner._prepare_inputs_non_dp = MagicMock(
return_value=(None, None, None, None, None, None))
return_value=(None, None, None, None, None, None, None))

scheduler_output = MagicMock()
self.runner._prepare_inputs(scheduler_output)
3 changes: 1 addition & 2 deletions tpu_inference/runner/kv_cache_manager.py
@@ -1,5 +1,4 @@
import functools
import math
from typing import TYPE_CHECKING, Dict, List

import jax
@@ -190,7 +189,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
num_blocks = kv_cache_tensor.size // page_size_bytes
dp_size = self.runner.vllm_config.sharding_config.total_dp_size
# num_blocks must be a multiple of dp_size
num_blocks = math.ceil(num_blocks / dp_size) * dp_size
num_blocks = (num_blocks // dp_size) * dp_size
# NOTE: we'll multiply the num_kv_heads by 2 in the function
kv_cache = create_kv_caches(
num_blocks=num_blocks,
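A small illustration of the rounding change above. Since num_blocks is derived from kv_cache_tensor.size, rounding up (the removed math.ceil form) can yield more blocks than the preallocated tensor actually holds, while rounding down keeps the count within the allocation and still divisible by dp_size; this motivation is inferred from the surrounding code, and the numbers below are made up.

    dp_size = 4
    num_blocks = 10  # hypothetical block count that fits in the allocated tensor

    old_rounding = -(-num_blocks // dp_size) * dp_size  # ceil: 12, exceeds the allocation
    new_rounding = (num_blocks // dp_size) * dp_size    # floor: 8, fits and is a multiple of dp_size

    assert old_rounding == 12
    assert new_rounding == 8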
101 changes: 59 additions & 42 deletions tpu_inference/runner/tpu_runner.py
@@ -15,7 +15,7 @@
from flax import nnx
from jax.experimental import mesh_utils
from jax.sharding import NamedSharding, PartitionSpec
from torchax.ops.mappings import j2t, j2t_dtype
from torchax.ops.mappings import j2t_dtype
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
has_kv_transfer_group)
@@ -154,6 +154,7 @@ class ExecuteModelState:
spec_decode_metadata: Optional[SpecDecodeMetadata]
kv_connector_output: Optional[KVConnectorOutput]
logits_indices_selector: Optional[List[int]] = None
padded_num_reqs: Optional[int] = None


@functools.partial(jax.jit, donate_argnums=(0, 1, 2))
@@ -191,19 +192,28 @@ def _substitute_placeholder_token(
return input_ids.at[token_in_tpu_cur_input_indices].set(update_values)


def _reorder_logits_indices(logprobs_lists: LogprobsLists,
logits_indices_selector: List[int]):
def _jax_logprobs_to_lists(logprobs_tensors,
logits_indices_selector=None,
cu_num_generated_tokens=None):
"""Convert JAX LogprobsTensors to LogprobsLists by converting JAX arrays to numpy."""
log_token_ids_list = logprobs_tensors.logprob_token_ids.tolist()
logprobs_list = logprobs_tensors.logprobs.tolist()
selected_token_ranks_list = logprobs_tensors.selected_token_ranks.tolist()

if logits_indices_selector is not None:
log_token_ids_list = [
log_token_ids_list[i] for i in logits_indices_selector
]
logprobs_list = [logprobs_list[i] for i in logits_indices_selector]
selected_token_ranks_list = [
selected_token_ranks_list[i] for i in logits_indices_selector
]

return LogprobsLists(
logprob_token_ids=[
logprobs_lists.logprob_token_ids[i]
for i in logits_indices_selector
],
logprobs=[logprobs_lists.logprobs[i] for i in logits_indices_selector],
sampled_token_ranks=[
logprobs_lists.sampled_token_ranks[i]
for i in logits_indices_selector
],
cu_num_generated_tokens=logprobs_lists.cu_num_generated_tokens,
logprob_token_ids=np.asarray(log_token_ids_list),
logprobs=np.asarray(logprobs_list),
sampled_token_ranks=np.asarray(selected_token_ranks_list),
cu_num_generated_tokens=cu_num_generated_tokens,
)
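
A toy sketch of the reordering step inside the new _jax_logprobs_to_lists helper. The arrays below are stand-ins rather than vLLM's LogprobsTensors, and the interpretation of logits_indices_selector (the DP-shuffled position of each pre-shuffle request) is inferred from the "Map logprobs back to the pre-dp shuffling order" comment later in this file.

    import numpy as np

    # One row per DP-shuffled request (toy values).
    logprobs_shuffled = np.array([[-0.10], [-0.20], [-0.30]])
    logits_indices_selector = [2, 0, 1]  # shuffled position of each pre-shuffle request

    # Same list-comprehension reordering used by _jax_logprobs_to_lists.
    logprobs_list = logprobs_shuffled.tolist()
    logprobs_original_order = np.asarray(
        [logprobs_list[i] for i in logits_indices_selector])
    # Rows now follow the pre-shuffle request order: [-0.30], [-0.10], [-0.20]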


@@ -552,16 +562,17 @@ def sample_tokens(

(scheduler_output, attn_metadata, input_ids, hidden_states, logits,
aux_hidden_states, spec_decode_metadata, kv_connector_output,
logits_indices_selector) = (
self.execute_model_state.scheduler_output,
self.execute_model_state.attn_metadata,
self.execute_model_state.input_ids,
self.execute_model_state.hidden_states,
self.execute_model_state.logits,
self.execute_model_state.aux_hidden_states,
self.execute_model_state.spec_decode_metadata,
self.execute_model_state.kv_connector_output,
self.execute_model_state.logits_indices_selector)
logits_indices_selector,
padded_num_reqs) = (self.execute_model_state.scheduler_output,
self.execute_model_state.attn_metadata,
self.execute_model_state.input_ids,
self.execute_model_state.hidden_states,
self.execute_model_state.logits,
self.execute_model_state.aux_hidden_states,
self.execute_model_state.spec_decode_metadata,
self.execute_model_state.kv_connector_output,
self.execute_model_state.logits_indices_selector,
self.execute_model_state.padded_num_reqs)
self.execute_model_state = None

if grammar_output is not None:
@@ -575,12 +586,10 @@
logits,
arange,
)
return self._sample_from_logits(scheduler_output, attn_metadata,
input_ids, hidden_states, logits,
aux_hidden_states,
spec_decode_metadata,
kv_connector_output,
logits_indices_selector)
return self._sample_from_logits(
scheduler_output, attn_metadata, input_ids, hidden_states, logits,
aux_hidden_states, spec_decode_metadata, kv_connector_output,
logits_indices_selector, padded_num_reqs)

def _modify_prev_results(self):
# If copy to host has not been done, we just wait.
@@ -694,6 +703,7 @@ def _execute_model(
logits_indices,
spec_decode_metadata,
logits_indices_selector,
padded_num_reqs,
) = self._prepare_inputs(scheduler_output)

# multi-modal support
@@ -756,7 +766,8 @@
aux_hidden_states=aux_hidden_states,
spec_decode_metadata=spec_decode_metadata,
kv_connector_output=kv_connector_output,
logits_indices_selector=logits_indices_selector)
logits_indices_selector=logits_indices_selector,
padded_num_reqs=padded_num_reqs)
return attn_metadata, None

def _sample_from_logits(
Expand All @@ -770,11 +781,19 @@ def _sample_from_logits(
spec_decode_metadata: Optional[SpecDecodeMetadata],
kv_connector_output: Optional[KVConnectorOutput],
logits_indices_selector: Optional[List[int]] = None,
padded_num_reqs: Optional[int] = None,
) -> ModelRunnerOutput | AsyncTPUModelRunnerOutput:
padded_num_reqs = runner_utils.get_padded_num_reqs_with_upper_limit(
self.input_batch.num_reqs, self.max_num_reqs)
if padded_num_reqs is None:
padded_num_reqs = runner_utils.get_padded_num_reqs_with_upper_limit(
self.input_batch.num_reqs, self.max_num_reqs)

sharding = None
if self.dp_size > 1:
sharding = NamedSharding(self.mesh,
PartitionSpec(ShardingAxisName.ATTN_DATA))

tpu_sampling_metadata = TPUSupportedSamplingMetadata.from_input_batch(
self.mesh, self.input_batch, padded_num_reqs)
self.mesh, self.input_batch, padded_num_reqs, sharding=sharding)
if spec_decode_metadata is None:
next_tokens = sample(
self.rng_params_for_sampling,
@@ -806,8 +825,6 @@ def _sample_from_logits(
if tpu_sampling_metadata.logprobs:
logprobs = self._compute_and_gather_logprobs(
logits, next_tokens, self.model_config.max_logprobs)
logprobs_lists = jax.tree.map(lambda x: j2t(x.astype(jnp.float32)),
logprobs).tolists()
else:
logprobs = None

@@ -860,9 +877,8 @@ def _sample_from_logits(

if logprobs is not None:
# Map logprobs back to the pre-dp shuffling order
if logits_indices_selector is not None:
logprobs_lists = _reorder_logits_indices(
logprobs_lists, logits_indices_selector)
logprobs_lists = _jax_logprobs_to_lists(
logprobs, logits_indices_selector)

else:
logprobs_lists = None
@@ -934,9 +950,8 @@ def _sample_from_logits(

if logprobs is not None:
# Map logprobs back to the pre-dp shuffling order
if logits_indices_selector is not None:
logprobs_lists = _reorder_logits_indices(
logprobs_lists, logits_indices_selector)
logprobs_lists = _jax_logprobs_to_lists(logprobs,
logits_indices_selector)
else:
logprobs_lists = None

Expand Down Expand Up @@ -1397,6 +1412,7 @@ def _prepare_inputs_dp(self, scheduler_output: "VllmSchedulerOutput"):
logits_indices,
spec_decode_metadata,
logits_indices_selector,
padded_num_reqs,
)

def _prepare_inputs_non_dp(self, scheduler_output: "VllmSchedulerOutput"):
Expand Down Expand Up @@ -1563,7 +1579,8 @@ def _prepare_inputs_non_dp(self, scheduler_output: "VllmSchedulerOutput"):
attention_metadata.seq_lens_cpu = seq_lens_cpu
logits_indices_selector = None
return (input_ids, attention_metadata, sampling_metadata,
logits_indices, spec_decode_metadata, logits_indices_selector)
logits_indices, spec_decode_metadata, logits_indices_selector,
padded_num_reqs)

def _get_input_ids_embeds(self, input_ids: jax.Array,
mm_embeds: list[jax.Array]):
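One hunk above attaches a NamedSharding over the attention data-parallel axis to the sampling metadata whenever dp_size > 1. Below is a standalone JAX sketch of that sharding pattern; the mesh construction and the "attn_dp" axis name are illustrative assumptions, not the runner's real mesh or ShardingAxisName.ATTN_DATA.

    import jax
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec

    # Build a simple 1-D device mesh; the runner's real mesh and axis name differ.
    devices = np.asarray(jax.devices())
    mesh = Mesh(devices, axis_names=("attn_dp",))

    # Shard the leading (per-request) dimension across the attention-DP axis,
    # mirroring NamedSharding(self.mesh, PartitionSpec(ShardingAxisName.ATTN_DATA)).
    sharding = NamedSharding(mesh, PartitionSpec("attn_dp"))

    # Example: place per-request sampling temperatures with that sharding.
    padded_num_reqs = 8 * len(devices)  # leading dim must divide evenly across devices
    temperatures = jax.device_put(
        np.zeros(padded_num_reqs, dtype=np.float32), sharding)
    print(temperatures.sharding)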