diff --git a/tests/v1/tpu/worker/test_tpu_model_runner.py b/tests/v1/tpu/worker/test_tpu_model_runner.py
index 1aa0709696c4..18aa599f1aaf 100644
--- a/tests/v1/tpu/worker/test_tpu_model_runner.py
+++ b/tests/v1/tpu/worker/test_tpu_model_runner.py
@@ -212,10 +212,12 @@ def test_update_states_request_resumed(model_runner):
     # resume req
     cached_req_data = CachedRequestData(
         req_ids=[req_id],
-        resumed_from_preemption=[False],
+        resumed_req_ids={req_id},
         new_token_ids=[[]],
+        all_token_ids={req_id: scheduler_output.scheduled_new_reqs[0].prompt_token_ids},
         new_block_ids=[([],)],
         num_computed_tokens=[0],
+        num_output_tokens=[0],
     )

     scheduler_output = SchedulerOutput(
diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py
index c2c34ee95ad5..9007436350be 100644
--- a/tests/v1/worker/test_gpu_model_runner.py
+++ b/tests/v1/worker/test_gpu_model_runner.py
@@ -259,10 +259,10 @@ def test_update_states_request_resumed(model_runner, dist_init):
     # resume req
     cached_req_data = CachedRequestData(
         req_ids=[req_id],
-        resumed_from_preemption=[False],
+        resumed_req_ids=set(),
         new_token_ids=[[]],
-        resumed_req_token_ids=[None],
-        new_block_ids=([[0]],),
+        all_token_ids={},
+        new_block_ids=[([0],)],
         num_computed_tokens=[0],
         num_output_tokens=[0],
     )
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
index 6d4ffc152de9..19344e5784c2 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
@@ -494,5 +494,5 @@ def yield_req_data(
     yield from zip(
         cached_reqs.req_ids,
         cached_reqs.new_block_ids,
-        cached_reqs.resumed_from_preemption,
+        (req_id in cached_reqs.resumed_req_ids for req_id in cached_reqs.req_ids),
     )
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
index e47cde2614fc..780dd12fccda 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
@@ -415,10 +415,10 @@ def build_connector_meta(
         for i, req_id in enumerate(cached_reqs.req_ids):
             num_computed_tokens = cached_reqs.num_computed_tokens[i]
             new_block_ids = cached_reqs.new_block_ids[i]
-            resumed_from_preemption = cached_reqs.resumed_from_preemption[i]
+            resumed_from_preemption = req_id in cached_reqs.resumed_req_ids

             if self.is_producer:
-                num_scheduled_tokens = (scheduler_output.num_scheduled_tokens)[req_id]
+                num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
                 num_tokens = num_scheduled_tokens + num_computed_tokens
                 assert req_id in self.chunked_prefill
                 assert new_block_ids is not None
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
index fc277630603a..9c230d7d0d2f 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
@@ -336,7 +336,7 @@ def build_connector_meta(
         cached_reqs = scheduler_output.scheduled_cached_reqs
         for i, req_id in enumerate(cached_reqs.req_ids):
-            resumed_from_preemption = cached_reqs.resumed_from_preemption[i]
+            resumed_from_preemption = req_id in cached_reqs.resumed_req_ids
             if not resumed_from_preemption or req_id not in self._requests_need_load:
                 continue
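The connector changes above all apply the same migration: the per-request flag is no longer read from a resumed_from_preemption list that parallels req_ids, but derived from membership in the resumed_req_ids set. A minimal, self-contained Python sketch of that pattern follows; the _CachedReqsStub class and iter_resumed_flags helper are illustrative stand-ins, not part of vLLM.

from dataclasses import dataclass, field


@dataclass
class _CachedReqsStub:
    # Illustrative stand-in exposing only the CachedRequestData fields used here.
    req_ids: list[str] = field(default_factory=list)
    resumed_req_ids: set[str] = field(default_factory=set)


def iter_resumed_flags(cached_reqs: _CachedReqsStub):
    # Replacement for the old zip(req_ids, resumed_from_preemption) pattern:
    # derive the boolean per request id from set membership.
    for req_id in cached_reqs.req_ids:
        yield req_id, req_id in cached_reqs.resumed_req_ids


if __name__ == "__main__":
    reqs = _CachedReqsStub(req_ids=["a", "b"], resumed_req_ids={"b"})
    assert list(iter_resumed_flags(reqs)) == [("a", False), ("b", True)]

Since preemption is rare, the set is usually empty, so the membership test stays cheap in the common case.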
diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py
index 035394f04530..cc6b89e2bf3f 100644
--- a/vllm/v1/core/sched/output.py
+++ b/vllm/v1/core/sched/output.py
@@ -2,8 +2,11 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 from dataclasses import dataclass
+from functools import cached_property
 from typing import TYPE_CHECKING

+from typing_extensions import deprecated
+
 from vllm._bc_linter import bc_linter_include

 if TYPE_CHECKING:
@@ -96,16 +99,16 @@ def anon_repr(self) -> str:
 @dataclass
 class CachedRequestData:
     req_ids: list[str]
-    # If resumed_from_preemption is False, new_block_ids will be appended to
-    # the request's block IDs. If True, new_block_ids will be used as the
+    # For request ids not in resumed_req_ids, new_block_ids will be appended to
+    # the request's block IDs. For those in the set, new_block_ids will be used as the
     # request's block IDs instead of appending to the existing block IDs.
-    resumed_from_preemption: list[bool]
+    resumed_req_ids: set[str]
     # NOTE(woosuk): new_token_ids is only used for pipeline parallelism.
     # When PP is not used, new_token_ids will be empty.
     new_token_ids: list[list[int]]
-    # If resumed_from_preemption is True, propogate the token ids to the
-    # connector, otherwise will be empty.
-    resumed_req_token_ids: list[list[int] | None]
+    # For requests not scheduled in the last step, propagate the token ids to the
+    # connector. Won't contain requests that were scheduled in the prior step.
+    all_token_ids: dict[str, list[int]]
     new_block_ids: list[tuple[list[int], ...] | None]
     num_computed_tokens: list[int]
     num_output_tokens: list[int]
@@ -114,13 +117,26 @@ class CachedRequestData:
     def num_reqs(self) -> int:
         return len(self.req_ids)

+    @cached_property
+    @deprecated("use resumed_req_ids field")
+    def resumed_from_preemption(self) -> list[bool]:
+        return [req_id in self.resumed_req_ids for req_id in self.req_ids]
+
+    @cached_property
+    @deprecated("use all_token_ids field")
+    def resumed_req_token_ids(self) -> list[list[int] | None]:
+        return [
+            self.all_token_ids[req_id] if req_id in self.resumed_req_ids else None
+            for req_id in self.req_ids
+        ]
+
     @classmethod
     def make_empty(cls) -> "CachedRequestData":
         return cls(
             req_ids=[],
-            resumed_from_preemption=[],
+            resumed_req_ids=set(),
             new_token_ids=[],
-            resumed_req_token_ids=[],
+            all_token_ids={},
             new_block_ids=[],
             num_computed_tokens=[],
             num_output_tokens=[],
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index 00b34fe4fbb9..c794886bc24c 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -71,6 +71,7 @@ def __init__(
         self.finished_req_ids_dict: dict[int, set[str]] | None = (
             defaultdict(set) if include_finished_set else None
         )
+        self.prev_step_scheduled_req_ids: set[str] = set()

         # Scheduling constraints.
         self.max_num_running_reqs = self.scheduler_config.max_num_seqs
@@ -444,14 +445,9 @@ def schedule(self) -> SchedulerOutput:
                 # `request.num_prompt_tokens` to consider the resumed
                 # requests, which have output tokens.
                 num_new_tokens = request.num_tokens - num_computed_tokens
-                if (
-                    0
-                    < self.scheduler_config.long_prefill_token_threshold
-                    < num_new_tokens
-                ):
-                    num_new_tokens = (
-                        self.scheduler_config.long_prefill_token_threshold
-                    )
+                threshold = self.scheduler_config.long_prefill_token_threshold
+                if 0 < threshold < num_new_tokens:
+                    num_new_tokens = threshold

                 # chunked prefill has to be enabled explicitly to allow
                 # pooling requests to be chunked
@@ -620,6 +616,11 @@ def schedule(self) -> SchedulerOutput:
         structured_output_request_ids, grammar_bitmask = self.get_grammar_bitmask(
             num_scheduled_tokens.keys(), scheduled_spec_decode_tokens
         )
+
+        # Record the request ids that were scheduled in this step.
+        self.prev_step_scheduled_req_ids.clear()
+        self.prev_step_scheduled_req_ids.update(num_scheduled_tokens.keys())
+
         scheduler_output = SchedulerOutput(
             scheduled_new_reqs=new_reqs_data,
             scheduled_cached_reqs=cached_reqs_data,
@@ -691,14 +692,12 @@ def _make_cached_request_data(
         req_ids: list[str] = []
         new_token_ids: list[list[int]] = []
         new_block_ids: list[tuple[list[int], ...] | None] = []
-        resumed_req_token_ids: list[list[int] | None] = []
+        all_token_ids: dict[str, list[int]] = {}
         num_computed_tokens: list[int] = []
         num_output_tokens: list[int] = []
+        resumed_req_ids = set()

-        # Because resumed_reqs is usually empty, it is more efficient to do
-        # in-place appending so that we don't need to allocate a new list.
-        resumed_from_preemption = [False] * len(running_reqs)
-        resumed_from_preemption += [True] * len(resumed_reqs)
+        num_running_reqs = len(running_reqs)
         for idx, req in enumerate(itertools.chain(running_reqs, resumed_reqs)):
             req_id = req.request_id
             req_ids.append(req_id)
@@ -715,12 +714,14 @@ def _make_cached_request_data(
                     req.num_computed_tokens : req.num_computed_tokens + num_tokens
                 ]
             new_token_ids.append(token_ids)
-            resumed_token_ids = None
-            if resumed_from_preemption[idx]:
-                resumed_token_ids = req.all_token_ids[
+            scheduled_in_prev_step = req_id in self.prev_step_scheduled_req_ids
+            if idx >= num_running_reqs:
+                assert not scheduled_in_prev_step
+                resumed_req_ids.add(req_id)
+            if not scheduled_in_prev_step:
+                all_token_ids[req_id] = req.all_token_ids[
                     : req.num_computed_tokens + num_tokens
                 ]
-            resumed_req_token_ids.append(resumed_token_ids)
             new_block_ids.append(
                 req_to_new_blocks[req_id].get_block_ids(allow_none=True)
             )
@@ -731,9 +732,9 @@ def _make_cached_request_data(

         return CachedRequestData(
             req_ids=req_ids,
-            resumed_from_preemption=resumed_from_preemption,
+            resumed_req_ids=resumed_req_ids,
             new_token_ids=new_token_ids,
-            resumed_req_token_ids=resumed_req_token_ids,
+            all_token_ids=all_token_ids,
             new_block_ids=new_block_ids,
             num_computed_tokens=num_computed_tokens,
             num_output_tokens=num_output_tokens,
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index e350988456f1..1fe749c614cc 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -706,7 +706,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
             req_state = self.requests[req_id]
             num_computed_tokens = req_data.num_computed_tokens[i]
             new_block_ids = req_data.new_block_ids[i]
-            resumed_from_preemption = req_data.resumed_from_preemption[i]
+            resumed_from_preemption = req_id in req_data.resumed_req_ids
             num_output_tokens = req_data.num_output_tokens[i]

             # Update the cached states.
@@ -754,16 +754,17 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
                 # Replace the existing block IDs with the new ones.
                 req_state.block_ids = new_block_ids

-            if self.use_async_scheduling and num_output_tokens > 0:
-                # We must recover the output token ids for resumed requests in the
-                # async scheduling case, so that correct input_ids are obtained.
-                resumed_token_ids = req_data.resumed_req_token_ids[i]
-                assert resumed_token_ids is not None
-                req_state.output_token_ids = resumed_token_ids[-num_output_tokens:]

             if req_index is None:
                 # The request is not in the persistent batch.
                 # The request was either preempted and resumed later, or was not
                 # scheduled in the previous step and needs to be added again.
+
+                if self.use_async_scheduling and num_output_tokens > 0:
+                    # We must recover the output token ids for resumed requests in the
+                    # async scheduling case, so that correct input_ids are obtained.
+                    resumed_token_ids = req_data.all_token_ids[req_id]
+                    req_state.output_token_ids = resumed_token_ids[-num_output_tokens:]
+
                 reqs_to_add.append(req_state)
                 continue
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index 5d7b181989ce..0ced138b940d 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -483,7 +483,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
             req_state = self.requests[req_id]
             num_computed_tokens = req_data.num_computed_tokens[i]
             new_block_ids = req_data.new_block_ids[i]
-            resumed_from_preemption = req_data.resumed_from_preemption[i]
+            resumed_from_preemption = req_id in req_data.resumed_req_ids

             # Update the cached states.
             req_state.num_computed_tokens = num_computed_tokens
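For reference, the append-versus-replace semantics that the CachedRequestData comment describes, and that the model runners above apply to block IDs, can be sketched in isolation. The helper below is illustrative only: the function name and the plain list-of-lists layout are assumptions, whereas the real code keeps a tuple of per-KV-cache-group block lists on the cached request state (Python 3.10+ union syntax assumed).

def apply_new_block_ids(
    current_block_ids: list[list[int]],
    new_block_ids: list[list[int]] | None,
    resumed_from_preemption: bool,
) -> list[list[int]]:
    # No new blocks were allocated for this request in this step.
    if new_block_ids is None:
        return current_block_ids
    if resumed_from_preemption:
        # Resumed after preemption: the scheduler sent the full block list,
        # so it replaces whatever the worker still has cached.
        return [list(group) for group in new_block_ids]
    # Still-running request: append the newly allocated blocks per group.
    return [
        existing + list(group)
        for existing, group in zip(current_block_ids, new_block_ids)
    ]


if __name__ == "__main__":
    assert apply_new_block_ids([[1, 2]], [[3]], False) == [[1, 2, 3]]
    assert apply_new_block_ids([[1, 2]], [[7, 8]], True) == [[7, 8]]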