Commit cb338b9

njhill authored and MatthewBonanni committed
[BugFix] Handle unscheduled requests properly when async scheduling (vllm-project#27756)
Signed-off-by: Nick Hill <nhill@redhat.com>
1 parent d1ebffb commit cb338b9

9 files changed, +63 -43 lines changed

tests/v1/tpu/worker/test_tpu_model_runner.py

Lines changed: 3 additions & 1 deletion
@@ -212,10 +212,12 @@ def test_update_states_request_resumed(model_runner):
     # resume req
     cached_req_data = CachedRequestData(
         req_ids=[req_id],
-        resumed_from_preemption=[False],
+        resumed_req_ids={req_id},
         new_token_ids=[[]],
+        all_token_ids={req_id: scheduler_output.scheduled_new_reqs[0].prompt_token_ids},
         new_block_ids=[([],)],
         num_computed_tokens=[0],
+        num_output_tokens=[0],
     )

     scheduler_output = SchedulerOutput(

tests/v1/worker/test_gpu_model_runner.py

Lines changed: 3 additions & 3 deletions
@@ -259,10 +259,10 @@ def test_update_states_request_resumed(model_runner, dist_init):
     # resume req
     cached_req_data = CachedRequestData(
         req_ids=[req_id],
-        resumed_from_preemption=[False],
+        resumed_req_ids=set(),
         new_token_ids=[[]],
-        resumed_req_token_ids=[None],
-        new_block_ids=([[0]],),
+        all_token_ids={},
+        new_block_ids=[([0],)],
         num_computed_tokens=[0],
         num_output_tokens=[0],
     )

vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py

Lines changed: 1 addition & 1 deletion
@@ -494,5 +494,5 @@ def yield_req_data(
     yield from zip(
         cached_reqs.req_ids,
         cached_reqs.new_block_ids,
-        cached_reqs.resumed_from_preemption,
+        (req_id in cached_reqs.resumed_req_ids for req_id in cached_reqs.req_ids),
     )
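
The change above is representative of how every consumer of CachedRequestData migrates: the per-index bool list resumed_from_preemption[i] becomes a membership test against the resumed_req_ids set. A minimal sketch of the pattern, assuming only the field names from this diff (the loop body itself is illustrative, not code from the commit):

    # cached_reqs is a CachedRequestData instance (see vllm/v1/core/sched/output.py).
    for i, req_id in enumerate(cached_reqs.req_ids):
        # Old API: positional lookup into a parallel list of bools.
        #   resumed = cached_reqs.resumed_from_preemption[i]
        # New API: membership test against the set of resumed request ids.
        resumed = req_id in cached_reqs.resumed_req_ids
        new_block_ids = cached_reqs.new_block_ids[i]

The generator expression used in yield_req_data above reproduces the old per-request ordering lazily, so the existing zip-based consumer keeps working without materializing a new list.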

vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py

Lines changed: 2 additions & 2 deletions
@@ -415,10 +415,10 @@ def build_connector_meta(
         for i, req_id in enumerate(cached_reqs.req_ids):
             num_computed_tokens = cached_reqs.num_computed_tokens[i]
             new_block_ids = cached_reqs.new_block_ids[i]
-            resumed_from_preemption = cached_reqs.resumed_from_preemption[i]
+            resumed_from_preemption = req_id in cached_reqs.resumed_req_ids

             if self.is_producer:
-                num_scheduled_tokens = (scheduler_output.num_scheduled_tokens)[req_id]
+                num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
                 num_tokens = num_scheduled_tokens + num_computed_tokens
                 assert req_id in self.chunked_prefill
                 assert new_block_ids is not None

vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py

Lines changed: 1 addition & 1 deletion
@@ -336,7 +336,7 @@ def build_connector_meta(

         cached_reqs = scheduler_output.scheduled_cached_reqs
         for i, req_id in enumerate(cached_reqs.req_ids):
-            resumed_from_preemption = cached_reqs.resumed_from_preemption[i]
+            resumed_from_preemption = req_id in cached_reqs.resumed_req_ids
             if not resumed_from_preemption or req_id not in self._requests_need_load:
                 continue

vllm/v1/core/sched/output.py

Lines changed: 24 additions & 8 deletions
@@ -2,8 +2,11 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 from dataclasses import dataclass
+from functools import cached_property
 from typing import TYPE_CHECKING

+from typing_extensions import deprecated
+
 from vllm._bc_linter import bc_linter_include

 if TYPE_CHECKING:
@@ -96,16 +99,16 @@ def anon_repr(self) -> str:
 @dataclass
 class CachedRequestData:
     req_ids: list[str]
-    # If resumed_from_preemption is False, new_block_ids will be appended to
-    # the request's block IDs. If True, new_block_ids will be used as the
+    # For request ids not in resumed_req_ids, new_block_ids will be appended to
+    # the request's block IDs. For those in the set, new_block_ids will be used as the
     # request's block IDs instead of appending to the existing block IDs.
-    resumed_from_preemption: list[bool]
+    resumed_req_ids: set[str]
     # NOTE(woosuk): new_token_ids is only used for pipeline parallelism.
     # When PP is not used, new_token_ids will be empty.
     new_token_ids: list[list[int]]
-    # If resumed_from_preemption is True, propogate the token ids to the
-    # connector, otherwise will be empty.
-    resumed_req_token_ids: list[list[int] | None]
+    # For requests not scheduled in the last step, propagate the token ids to the
+    # connector. Won't contain requests that were scheduled in the prior step.
+    all_token_ids: dict[str, list[int]]
     new_block_ids: list[tuple[list[int], ...] | None]
     num_computed_tokens: list[int]
     num_output_tokens: list[int]
@@ -114,13 +117,26 @@ class CachedRequestData:
     def num_reqs(self) -> int:
         return len(self.req_ids)

+    @cached_property
+    @deprecated("use resumed_req_ids field")
+    def resumed_from_preemption(self) -> list[bool]:
+        return [req_id in self.resumed_req_ids for req_id in self.req_ids]
+
+    @cached_property
+    @deprecated("use all_token_ids field")
+    def resumed_req_token_ids(self) -> list[list[int] | None]:
+        return [
+            self.all_token_ids[req_id] if req_id in self.resumed_req_ids else None
+            for req_id in self.req_ids
+        ]
+
     @classmethod
     def make_empty(cls) -> "CachedRequestData":
         return cls(
             req_ids=[],
-            resumed_from_preemption=[],
+            resumed_req_ids=set(),
             new_token_ids=[],
-            resumed_req_token_ids=[],
+            all_token_ids={},
             new_block_ids=[],
             num_computed_tokens=[],
             num_output_tokens=[],
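
A short usage sketch of the reshaped dataclass, with invented values purely for illustration: the new set/dict fields carry the state, while the deprecated cached_property views reconstruct the old list-based layout for callers that have not migrated yet.

    # Hypothetical values for illustration only.
    data = CachedRequestData(
        req_ids=["req-0", "req-1"],
        resumed_req_ids={"req-1"},                # req-1 was preempted and resumed
        new_token_ids=[[], []],
        all_token_ids={"req-1": [1, 2, 3, 4]},    # only reqs not scheduled last step
        new_block_ids=[([5],), ([0, 1],)],
        num_computed_tokens=[8, 0],
        num_output_tokens=[2, 1],
    )
    # Deprecated compatibility views, computed once on first access and cached:
    assert data.resumed_from_preemption == [False, True]
    assert data.resumed_req_token_ids == [None, [1, 2, 3, 4]]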

vllm/v1/core/sched/scheduler.py

Lines changed: 20 additions & 19 deletions
@@ -71,6 +71,7 @@ def __init__(
         self.finished_req_ids_dict: dict[int, set[str]] | None = (
             defaultdict(set) if include_finished_set else None
         )
+        self.prev_step_scheduled_req_ids: set[str] = set()

         # Scheduling constraints.
         self.max_num_running_reqs = self.scheduler_config.max_num_seqs
@@ -444,14 +445,9 @@ def schedule(self) -> SchedulerOutput:
                 # `request.num_prompt_tokens` to consider the resumed
                 # requests, which have output tokens.
                 num_new_tokens = request.num_tokens - num_computed_tokens
-                if (
-                    0
-                    < self.scheduler_config.long_prefill_token_threshold
-                    < num_new_tokens
-                ):
-                    num_new_tokens = (
-                        self.scheduler_config.long_prefill_token_threshold
-                    )
+                threshold = self.scheduler_config.long_prefill_token_threshold
+                if 0 < threshold < num_new_tokens:
+                    num_new_tokens = threshold

                 # chunked prefill has to be enabled explicitly to allow
                 # pooling requests to be chunked
@@ -620,6 +616,11 @@ def schedule(self) -> SchedulerOutput:
         structured_output_request_ids, grammar_bitmask = self.get_grammar_bitmask(
             num_scheduled_tokens.keys(), scheduled_spec_decode_tokens
         )
+
+        # Record the request ids that were scheduled in this step.
+        self.prev_step_scheduled_req_ids.clear()
+        self.prev_step_scheduled_req_ids.update(num_scheduled_tokens.keys())
+
         scheduler_output = SchedulerOutput(
             scheduled_new_reqs=new_reqs_data,
             scheduled_cached_reqs=cached_reqs_data,
@@ -691,14 +692,12 @@ def _make_cached_request_data(
         req_ids: list[str] = []
         new_token_ids: list[list[int]] = []
         new_block_ids: list[tuple[list[int], ...] | None] = []
-        resumed_req_token_ids: list[list[int] | None] = []
+        all_token_ids: dict[str, list[int]] = {}
         num_computed_tokens: list[int] = []
         num_output_tokens: list[int] = []
+        resumed_req_ids = set()

-        # Because resumed_reqs is usually empty, it is more efficient to do
-        # in-place appending so that we don't need to allocate a new list.
-        resumed_from_preemption = [False] * len(running_reqs)
-        resumed_from_preemption += [True] * len(resumed_reqs)
+        num_running_reqs = len(running_reqs)
         for idx, req in enumerate(itertools.chain(running_reqs, resumed_reqs)):
             req_id = req.request_id
             req_ids.append(req_id)
@@ -715,12 +714,14 @@ def _make_cached_request_data(
                 req.num_computed_tokens : req.num_computed_tokens + num_tokens
             ]
             new_token_ids.append(token_ids)
-            resumed_token_ids = None
-            if resumed_from_preemption[idx]:
-                resumed_token_ids = req.all_token_ids[
+            scheduled_in_prev_step = req_id in self.prev_step_scheduled_req_ids
+            if idx >= num_running_reqs:
+                assert not scheduled_in_prev_step
+                resumed_req_ids.add(req_id)
+            if not scheduled_in_prev_step:
+                all_token_ids[req_id] = req.all_token_ids[
                     : req.num_computed_tokens + num_tokens
                 ]
-            resumed_req_token_ids.append(resumed_token_ids)
             new_block_ids.append(
                 req_to_new_blocks[req_id].get_block_ids(allow_none=True)
             )
@@ -731,9 +732,9 @@ def _make_cached_request_data(

         return CachedRequestData(
             req_ids=req_ids,
-            resumed_from_preemption=resumed_from_preemption,
+            resumed_req_ids=resumed_req_ids,
             new_token_ids=new_token_ids,
-            resumed_req_token_ids=resumed_req_token_ids,
+            all_token_ids=all_token_ids,
             new_block_ids=new_block_ids,
             num_computed_tokens=num_computed_tokens,
             num_output_tokens=num_output_tokens,
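
The scheduler-side bookkeeping above boils down to two decisions per cached request: preempted-and-resumed requests go into resumed_req_ids, and any request that was not scheduled in the previous step gets its full token ids propagated via all_token_ids. A simplified, self-contained sketch of that logic (classify is a hypothetical helper; the real _make_cached_request_data also slices token ids and collects block ids):

    def classify(
        running: list[str],        # req ids still in the running queue
        resumed: list[str],        # req ids resumed after preemption
        prev_scheduled: set[str],  # req ids scheduled in the previous step
    ) -> tuple[set[str], set[str]]:
        resumed_req_ids: set[str] = set()
        needs_full_token_ids: set[str] = set()
        for idx, req_id in enumerate(running + resumed):
            if idx >= len(running):
                # Preempted-and-resumed requests always rebuild worker state.
                assert req_id not in prev_scheduled
                resumed_req_ids.add(req_id)
            if req_id not in prev_scheduled:
                # Requests skipped last step must ship their token ids so the
                # worker can restore output tokens under async scheduling.
                needs_full_token_ids.add(req_id)
        return resumed_req_ids, needs_full_token_ids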

vllm/v1/worker/gpu_model_runner.py

Lines changed: 8 additions & 7 deletions
@@ -706,7 +706,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
             req_state = self.requests[req_id]
             num_computed_tokens = req_data.num_computed_tokens[i]
             new_block_ids = req_data.new_block_ids[i]
-            resumed_from_preemption = req_data.resumed_from_preemption[i]
+            resumed_from_preemption = req_id in req_data.resumed_req_ids
             num_output_tokens = req_data.num_output_tokens[i]

             # Update the cached states.
@@ -754,16 +754,17 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
                 # Replace the existing block IDs with the new ones.
                 req_state.block_ids = new_block_ids

-            if self.use_async_scheduling and num_output_tokens > 0:
-                # We must recover the output token ids for resumed requests in the
-                # async scheduling case, so that correct input_ids are obtained.
-                resumed_token_ids = req_data.resumed_req_token_ids[i]
-                assert resumed_token_ids is not None
-                req_state.output_token_ids = resumed_token_ids[-num_output_tokens:]
             if req_index is None:
                 # The request is not in the persistent batch.
                 # The request was either preempted and resumed later, or was not
                 # scheduled in the previous step and needs to be added again.
+
+                if self.use_async_scheduling and num_output_tokens > 0:
+                    # We must recover the output token ids for resumed requests in the
+                    # async scheduling case, so that correct input_ids are obtained.
+                    resumed_token_ids = req_data.all_token_ids[req_id]
+                    req_state.output_token_ids = resumed_token_ids[-num_output_tokens:]
+
                 reqs_to_add.append(req_state)
                 continue
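
With the fix, recovery of the output token ids only runs for requests that are being re-added to the persistent batch (req_index is None), and the tokens now come from the all_token_ids dict keyed by request id rather than a positional list. A small arithmetic sketch with made-up values:

    # Suppose req-0 has produced 2 output tokens but was not scheduled last step.
    all_token_ids = {"req-0": [11, 12, 13, 14, 15]}   # prompt + outputs so far
    num_output_tokens = 2
    output_token_ids = all_token_ids["req-0"][-num_output_tokens:]
    assert output_token_ids == [14, 15]               # restored for correct input_ids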

vllm/v1/worker/tpu_model_runner.py

Lines changed: 1 addition & 1 deletion
@@ -483,7 +483,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
             req_state = self.requests[req_id]
             num_computed_tokens = req_data.num_computed_tokens[i]
             new_block_ids = req_data.new_block_ids[i]
-            resumed_from_preemption = req_data.resumed_from_preemption[i]
+            resumed_from_preemption = req_id in req_data.resumed_req_ids

             # Update the cached states.
             req_state.num_computed_tokens = num_computed_tokens
