
Commit 8ba5cc9

fix unit tests
Signed-off-by: Chenyaaang <chenyangli@google.com>
1 parent f7f2b52 commit 8ba5cc9

File tree

3 files changed (+33, -14 lines)


tests/worker/tpu_worker_test.py

Lines changed: 22 additions & 6 deletions
@@ -25,6 +25,7 @@ def mock_vllm_config():
     mock_parallel_conf = MagicMock()
     mock_parallel_conf.tensor_parallel_size = 2
     mock_parallel_conf.data_parallel_size = 1
+    mock_parallel_conf.pipeline_parallel_size = 1
     mock_parallel_conf.nnodes = 1
     mock_parallel_conf.nnodes_within_dp = 1

@@ -118,8 +119,14 @@ def test_init_device_with_provided_devices(

         worker.init_device()

-        mock_jax.devices.assert_not_called()
-        mock_runner_cls.assert_called_once_with(mock_vllm_config, mock_devices)
+        mock_jax.local_devices.assert_not_called()
+        expected_rank = 0
+        expected_is_first_rank = True
+        expected_is_last_rank = True
+        mock_runner_cls.assert_called_once_with(mock_vllm_config, mock_devices,
+                                                expected_rank,
+                                                expected_is_first_rank,
+                                                expected_is_last_rank)
         assert isinstance(worker.model_runner, MagicMock)

     @patch('tpu_inference.worker.tpu_worker.TPUModelRunner')
@@ -137,15 +144,24 @@ def test_init_device_autodetects_devices(
             distributed_init_method="test_method",
             devices=[]  # No devices provided, should trigger auto-detection
         )
-        mock_jax.devices.return_value = ['tpu:0', 'tpu:1', 'tpu:2', 'tpu:3']
+        mock_jax.local_device_count.return_value = 4
+        mock_jax.local_devices.return_value = [
+            'tpu:0', 'tpu:1', 'tpu:2', 'tpu:3'
+        ]

         worker.init_device()

-        mock_jax.devices.assert_called_once()
+        mock_jax.local_devices.assert_called_once()
         expected_devices = ['tpu:0', 'tpu:1']  # Sliced by tensor_parallel_size
         assert worker.devices == expected_devices
+        expected_rank = 0
+        expected_is_first_rank = True
+        expected_is_last_rank = True
         mock_runner_cls.assert_called_once_with(mock_vllm_config,
-                                                expected_devices)
+                                                expected_devices,
+                                                expected_rank,
+                                                expected_is_first_rank,
+                                                expected_is_last_rank)

     @patch('tpu_inference.worker.tpu_worker.utils')
     def test_determine_available_memory(self, mock_utils, mock_vllm_config):
@@ -194,7 +210,7 @@ def test_execute_model(self, mock_runner_cls, mock_vllm_config):

         # Assert the runner was called with the scheduler output directly
         worker.model_runner.execute_model.assert_called_once_with(
-            mock_scheduler_input)
+            mock_scheduler_input, None)
         # Assert the final result is the concrete model output
         assert result == mock_model_output

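For reference, here is a minimal standalone sketch of the device-slicing expectation the auto-detection test encodes. The slice itself is an assumption inferred from the test's own comment ("Sliced by tensor_parallel_size"), not code copied from the worker:

    # With tensor_parallel_size = 2, the test expects the worker to keep the
    # first two entries returned by the mocked jax.local_devices().
    local_devices = ['tpu:0', 'tpu:1', 'tpu:2', 'tpu:3']   # mocked return value
    tensor_parallel_size = 2                                # from mock_parallel_conf

    expected_devices = local_devices[:tensor_parallel_size]
    assert expected_devices == ['tpu:0', 'tpu:1']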
tpu_inference/runner/tpu_runner.py

Lines changed: 3 additions & 0 deletions
@@ -223,6 +223,9 @@ def __init__(
         self,
         vllm_config: VllmConfig,
         devices: List[Any],
+        rank: int = 0,
+        is_first_rank: bool = True,
+        is_last_rank: bool = True,
     ):
         self.vllm_config = vllm_config
         self.model_config = vllm_config.model_config

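The hunk above only widens the constructor, so existing single-rank call sites keep working through the new defaults. A self-contained sketch of that behaviour, using a stand-in class rather than the real TPUModelRunner:

    from typing import Any, List

    class RunnerSketch:
        """Stand-in mirroring only the new TPUModelRunner.__init__ parameters."""
        def __init__(self, vllm_config: Any, devices: List[Any],
                     rank: int = 0, is_first_rank: bool = True,
                     is_last_rank: bool = True):
            self.rank = rank
            self.is_first_rank = is_first_rank
            self.is_last_rank = is_last_rank

    # Single-rank callers (the pre-change style) fall back to the defaults.
    single = RunnerSketch(vllm_config=None, devices=['tpu:0', 'tpu:1'])
    assert (single.rank, single.is_first_rank, single.is_last_rank) == (0, True, True)

    # Pipeline-parallel callers pass the values explicitly, as tpu_worker.py does below.
    rank, pp_world_size = 1, 2
    pp = RunnerSketch(None, ['tpu:0'], rank, rank == 0, rank == pp_world_size - 1)
    assert (pp.is_first_rank, pp.is_last_rank) == (False, True)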
tpu_inference/worker/tpu_worker.py

Lines changed: 8 additions & 8 deletions
@@ -117,7 +117,7 @@ def __init__(
         # TPU Worker is initialized. The profiler server needs to start after
         # MP runtime is initialized.
         self.profile_dir = None
-        if envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1 and self.pp_world_size == 1:
+        if vllm_envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1 and self.pp_config.pp_world_size == 1:
             if not self.devices or 0 in self.device_ranks:
                 # For TPU, we can only have 1 active profiler session for 1 profiler
                 # server. So we only profile on rank0.
@@ -126,9 +126,9 @@ def __init__(
                             self.profile_dir)

         # For PP, we use MPMD so we want to profile every worker.
-        if self.pp_world_size > 1 and envs.VLLM_TORCH_PROFILER_DIR:
+        if self.pp_config.pp_world_size > 1 and vllm_envs.VLLM_TORCH_PROFILER_DIR:
             self.profile_dir = os.path.join(
-                envs.VLLM_TORCH_PROFILER_DIR,
+                vllm_envs.VLLM_TORCH_PROFILER_DIR,
                 f"pprank_{self.rank}_ppworldsize_{self.pp_config.pp_world_size}"
             )
             os.makedirs(self.profile_dir, exist_ok=True)
@@ -161,7 +161,7 @@ def init_device(self,
         if multihost_backend != "ray" and self.parallel_config.pipeline_parallel_size > 1:
             tpu_ports = [
                 jax_parallel_state.BASE_JAX_PORT + i
-                for i in range(self.pp_world_size)
+                for i in range(self.pp_config.pp_world_size)
             ]
             os.environ["TPU_PROCESS_ADDRESSES"] = ",".join(
                 [f"localhost:{port}" for port in tpu_ports])
@@ -206,7 +206,7 @@ def init_device(self,
                 if device is None:
                     raise KeyError(
                         f"Device index {device_index} not found in "
-                        f"jax.devices() with IDs {list(device_dict.keys())}!"
+                        f"jax.local_devices() with IDs {list(device_dict.keys())}!"
                     )
                 self.devices.append(device)
             assert len(self.devices) >= sharding_config.total_devices
@@ -240,9 +240,9 @@ def init_device(self,
             need_pp=self.parallel_config.pipeline_parallel_size > 1)

         ensure_kv_transfer_initialized(self.vllm_config)
-        self.model_runner = TPUModelRunner(self.vllm_config, self.devices,
-                                           self.rank, self.rank == 0,
-                                           self.rank == self.pp_world_size - 1)
+        self.model_runner = TPUModelRunner(
+            self.vllm_config, self.devices, self.rank, self.rank == 0,
+            self.rank == self.pp_config.pp_world_size - 1)
         logger.info(f"Init worker | "
                     f"rank={self.rank} | "
                     f"node_id={get_node_id()} | "

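A short sketch of how the PP-aware values used in this file fit together. The expressions mirror the diff above; the base directory string is a made-up example standing in for vllm_envs.VLLM_TORCH_PROFILER_DIR:

    import os

    rank, pp_world_size = 1, 2
    is_first_rank = rank == 0                   # False for rank 1
    is_last_rank = rank == pp_world_size - 1    # True for the last PP rank

    base_dir = "/tmp/vllm-profile"              # stand-in for VLLM_TORCH_PROFILER_DIR
    profile_dir = os.path.join(base_dir, f"pprank_{rank}_ppworldsize_{pp_world_size}")
    assert profile_dir.endswith("pprank_1_ppworldsize_2")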