Merged
4 changes: 3 additions & 1 deletion tests/v1/tpu/worker/test_tpu_model_runner.py
@@ -212,10 +212,12 @@ def test_update_states_request_resumed(model_runner):
# resume req
cached_req_data = CachedRequestData(
req_ids=[req_id],
resumed_from_preemption=[False],
resumed_req_ids={req_id},
new_token_ids=[[]],
all_token_ids={req_id: scheduler_output.scheduled_new_reqs[0].prompt_token_ids},
new_block_ids=[([],)],
num_computed_tokens=[0],
num_output_tokens=[0],
)

scheduler_output = SchedulerOutput(
6 changes: 3 additions & 3 deletions tests/v1/worker/test_gpu_model_runner.py
@@ -259,10 +259,10 @@ def test_update_states_request_resumed(model_runner, dist_init):
# resume req
cached_req_data = CachedRequestData(
req_ids=[req_id],
resumed_from_preemption=[False],
resumed_req_ids=set(),
new_token_ids=[[]],
resumed_req_token_ids=[None],
new_block_ids=([[0]],),
all_token_ids={},
new_block_ids=[([0],)],
num_computed_tokens=[0],
num_output_tokens=[0],
)
@@ -494,5 +494,5 @@ def yield_req_data(
yield from zip(
cached_reqs.req_ids,
cached_reqs.new_block_ids,
cached_reqs.resumed_from_preemption,
(req_id in cached_reqs.resumed_req_ids for req_id in cached_reqs.req_ids),
)
@@ -415,10 +415,10 @@ def build_connector_meta(
for i, req_id in enumerate(cached_reqs.req_ids):
num_computed_tokens = cached_reqs.num_computed_tokens[i]
new_block_ids = cached_reqs.new_block_ids[i]
resumed_from_preemption = cached_reqs.resumed_from_preemption[i]
resumed_from_preemption = req_id in cached_reqs.resumed_req_ids

if self.is_producer:
num_scheduled_tokens = (scheduler_output.num_scheduled_tokens)[req_id]
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
num_tokens = num_scheduled_tokens + num_computed_tokens
assert req_id in self.chunked_prefill
assert new_block_ids is not None
@@ -336,7 +336,7 @@ def build_connector_meta(

cached_reqs = scheduler_output.scheduled_cached_reqs
for i, req_id in enumerate(cached_reqs.req_ids):
resumed_from_preemption = cached_reqs.resumed_from_preemption[i]
resumed_from_preemption = req_id in cached_reqs.resumed_req_ids
if not resumed_from_preemption or req_id not in self._requests_need_load:
continue

32 changes: 24 additions & 8 deletions vllm/v1/core/sched/output.py
@@ -2,8 +2,11 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass
from functools import cached_property
from typing import TYPE_CHECKING

from typing_extensions import deprecated

from vllm._bc_linter import bc_linter_include

if TYPE_CHECKING:
@@ -94,18 +97,18 @@

@bc_linter_include
@dataclass
class CachedRequestData:

req_ids: list[str]
# If resumed_from_preemption is False, new_block_ids will be appended to
# the request's block IDs. If True, new_block_ids will be used as the
# For request ids not in resumed_req_ids, new_block_ids will be appended to
# the request's block IDs. For those in the set, new_block_ids will be used as the
# request's block IDs instead of appending to the existing block IDs.
resumed_from_preemption: list[bool]
resumed_req_ids: set[str]

Member Author:
Changing this to a set since these will be rare and we currently are creating a [None] * batch_size list every time.
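A minimal sketch of this trade-off (the batch size and request ids below are made up): the old layout allocated a per-request flag list on every scheduler step even when nothing was preempted, while the set stays empty in the common case and the flag becomes a membership test.

```python
# Illustrative comparison of the two layouts; not vLLM code.
req_ids = [f"req-{i}" for i in range(512)]  # hypothetical batch

# Old layout: one entry per request, allocated on every scheduler step.
resumed_from_preemption = [False] * len(req_ids)
old_flags = list(resumed_from_preemption)

# New layout: a set that is empty in the common (no preemption) case.
resumed_req_ids: set[str] = set()
new_flags = [req_id in resumed_req_ids for req_id in req_ids]

assert old_flags == new_flags  # same information, no per-step list allocation
```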

# NOTE(woosuk): new_token_ids is only used for pipeline parallelism.
# When PP is not used, new_token_ids will be empty.
new_token_ids: list[list[int]]
# If resumed_from_preemption is True, propogate the token ids to the
# connector, otherwise will be empty.
resumed_req_token_ids: list[list[int] | None]
# For requests not scheduled in the last step, propagate the token ids to the
# connector. Won't contain requests that were scheduled in the prior step.
all_token_ids: dict[str, list[int]]

new_block_ids: list[tuple[list[int], ...] | None]
num_computed_tokens: list[int]
num_output_tokens: list[int]
@@ -114,13 +117,26 @@
def num_reqs(self) -> int:
return len(self.req_ids)

@cached_property
@deprecated("use resumed_req_ids field")
def resumed_from_preemption(self) -> list[bool]:
return [req_id in self.resumed_req_ids for req_id in self.req_ids]

@cached_property
@deprecated("use all_token_ids field")
def resumed_req_token_ids(self) -> list[list[int] | None]:
return [
self.all_token_ids[req_id] if req_id in self.resumed_req_ids else None
for req_id in self.req_ids
]
Member Author (on lines +120 to +131):
These are for backwards compatibility.
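A hedged sketch of what the shim buys existing callers (the ids and token lists below are invented): code that still reads the old list-shaped fields keeps working, because the deprecated properties rebuild those views from `resumed_req_ids` and `all_token_ids`.

```python
# Illustrative only: construct a CachedRequestData with made-up values and
# read it through both the new fields and the deprecated properties.
from vllm.v1.core.sched.output import CachedRequestData

data = CachedRequestData(
    req_ids=["a", "b"],
    resumed_req_ids={"b"},            # "b" was resumed from preemption
    new_token_ids=[[], []],
    all_token_ids={"b": [1, 2, 3]},   # only requests not scheduled last step
    new_block_ids=[([0],), ([1],)],
    num_computed_tokens=[3, 0],
    num_output_tokens=[1, 0],
)

# New-style access.
assert "b" in data.resumed_req_ids

# Old-style access still works via the backwards-compatibility properties.
assert data.resumed_from_preemption == [False, True]
assert data.resumed_req_token_ids == [None, [1, 2, 3]]
```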


@classmethod
def make_empty(cls) -> "CachedRequestData":
return cls(
req_ids=[],
resumed_from_preemption=[],
resumed_req_ids=set(),
new_token_ids=[],
resumed_req_token_ids=[],
all_token_ids={},
new_block_ids=[],
num_computed_tokens=[],
num_output_tokens=[],
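The field comment above describes how consumers are expected to combine `new_block_ids` with `resumed_req_ids`; a minimal sketch of that append-vs-replace rule (the block id values are made up, and `apply_new_blocks` is a hypothetical helper, not part of vLLM):

```python
# Hypothetical helper illustrating the documented semantics of new_block_ids.
def apply_new_blocks(
    existing: tuple[list[int], ...],
    new: tuple[list[int], ...],
    resumed: bool,
) -> tuple[list[int], ...]:
    if resumed:
        # Resumed from preemption: the new ids replace the old ones.
        return new
    # Otherwise the new ids extend the request's existing block ids
    # (one list per KV cache group).
    return tuple(old + extra for old, extra in zip(existing, new))

# Made-up block ids for a single KV cache group.
assert apply_new_blocks(([1, 2],), ([3],), resumed=False) == ([1, 2, 3],)
assert apply_new_blocks(([1, 2],), ([7, 8],), resumed=True) == ([7, 8],)
```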
39 changes: 20 additions & 19 deletions vllm/v1/core/sched/scheduler.py
@@ -71,6 +71,7 @@ def __init__(
self.finished_req_ids_dict: dict[int, set[str]] | None = (
defaultdict(set) if include_finished_set else None
)
self.prev_step_scheduled_req_ids: set[str] = set()

# Scheduling constraints.
self.max_num_running_reqs = self.scheduler_config.max_num_seqs
@@ -444,14 +445,9 @@ def schedule(self) -> SchedulerOutput:
# `request.num_prompt_tokens` to consider the resumed
# requests, which have output tokens.
num_new_tokens = request.num_tokens - num_computed_tokens
if (
0
< self.scheduler_config.long_prefill_token_threshold
< num_new_tokens
):
num_new_tokens = (
self.scheduler_config.long_prefill_token_threshold
)
threshold = self.scheduler_config.long_prefill_token_threshold
if 0 < threshold < num_new_tokens:
num_new_tokens = threshold
Member Author (on lines +448 to +450):
Unrelated simplification, hurt me to look at that formatting :)


# chunked prefill has to be enabled explicitly to allow
# pooling requests to be chunked
@@ -620,6 +616,11 @@ def schedule(self) -> SchedulerOutput:
structured_output_request_ids, grammar_bitmask = self.get_grammar_bitmask(
num_scheduled_tokens.keys(), scheduled_spec_decode_tokens
)

# Record the request ids that were scheduled in this step.
self.prev_step_scheduled_req_ids.clear()
self.prev_step_scheduled_req_ids.update(num_scheduled_tokens.keys())

scheduler_output = SchedulerOutput(
scheduled_new_reqs=new_reqs_data,
scheduled_cached_reqs=cached_reqs_data,
@@ -691,14 +692,12 @@ def _make_cached_request_data(
req_ids: list[str] = []
new_token_ids: list[list[int]] = []
new_block_ids: list[tuple[list[int], ...] | None] = []
resumed_req_token_ids: list[list[int] | None] = []
all_token_ids: dict[str, list[int]] = {}
num_computed_tokens: list[int] = []
num_output_tokens: list[int] = []
resumed_req_ids = set()

# Because resumed_reqs is usually empty, it is more efficient to do
# in-place appending so that we don't need to allocate a new list.
resumed_from_preemption = [False] * len(running_reqs)
resumed_from_preemption += [True] * len(resumed_reqs)
num_running_reqs = len(running_reqs)
for idx, req in enumerate(itertools.chain(running_reqs, resumed_reqs)):
req_id = req.request_id
req_ids.append(req_id)
@@ -715,12 +714,14 @@
req.num_computed_tokens : req.num_computed_tokens + num_tokens
]
new_token_ids.append(token_ids)
resumed_token_ids = None
if resumed_from_preemption[idx]:
resumed_token_ids = req.all_token_ids[
scheduled_in_prev_step = req_id in self.prev_step_scheduled_req_ids
if idx >= num_running_reqs:
assert not scheduled_in_prev_step
resumed_req_ids.add(req_id)
if not scheduled_in_prev_step:
all_token_ids[req_id] = req.all_token_ids[
: req.num_computed_tokens + num_tokens
]
resumed_req_token_ids.append(resumed_token_ids)
new_block_ids.append(
req_to_new_blocks[req_id].get_block_ids(allow_none=True)
)
Expand All @@ -731,9 +732,9 @@ def _make_cached_request_data(

return CachedRequestData(
req_ids=req_ids,
resumed_from_preemption=resumed_from_preemption,
resumed_req_ids=resumed_req_ids,
new_token_ids=new_token_ids,
resumed_req_token_ids=resumed_req_token_ids,
all_token_ids=all_token_ids,
new_block_ids=new_block_ids,
num_computed_tokens=num_computed_tokens,
num_output_tokens=num_output_tokens,
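A condensed sketch of the new scheduler bookkeeping above (all inputs are made up): full token lists are propagated only for cached requests that were not scheduled in the previous step, and only requests coming from the preempted list are marked as resumed.

```python
# Illustrative, with made-up inputs: which cached requests get full token
# lists propagated, and which are flagged as resumed from preemption.
prev_step_scheduled_req_ids = {"a"}          # scheduled in the previous step
running_req_ids = ["a", "c"]                 # "c" was skipped last step
resumed_reqs = ["b"]                         # "b" was preempted and resumed
all_tokens = {"a": [1, 2], "b": [3, 4, 5], "c": [6]}

resumed_req_ids: set[str] = set()
all_token_ids: dict[str, list[int]] = {}
for req_id in running_req_ids + resumed_reqs:
    scheduled_in_prev_step = req_id in prev_step_scheduled_req_ids
    if req_id in resumed_reqs:
        resumed_req_ids.add(req_id)
    if not scheduled_in_prev_step:
        all_token_ids[req_id] = all_tokens[req_id]

assert resumed_req_ids == {"b"}
assert set(all_token_ids) == {"b", "c"}      # "a" was scheduled last step
```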
15 changes: 8 additions & 7 deletions vllm/v1/worker/gpu_model_runner.py
@@ -706,7 +706,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
req_state = self.requests[req_id]
num_computed_tokens = req_data.num_computed_tokens[i]
new_block_ids = req_data.new_block_ids[i]
resumed_from_preemption = req_data.resumed_from_preemption[i]
resumed_from_preemption = req_id in req_data.resumed_req_ids
num_output_tokens = req_data.num_output_tokens[i]

# Update the cached states.
@@ -754,16 +754,17 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
# Replace the existing block IDs with the new ones.
req_state.block_ids = new_block_ids

if self.use_async_scheduling and num_output_tokens > 0:
# We must recover the output token ids for resumed requests in the
# async scheduling case, so that correct input_ids are obtained.
resumed_token_ids = req_data.resumed_req_token_ids[i]
assert resumed_token_ids is not None
req_state.output_token_ids = resumed_token_ids[-num_output_tokens:]
if req_index is None:
# The request is not in the persistent batch.
# The request was either preempted and resumed later, or was not
# scheduled in the previous step and needs to be added again.

if self.use_async_scheduling and num_output_tokens > 0:
# We must recover the output token ids for resumed requests in the
# async scheduling case, so that correct input_ids are obtained.
resumed_token_ids = req_data.all_token_ids[req_id]
req_state.output_token_ids = resumed_token_ids[-num_output_tokens:]

reqs_to_add.append(req_state)
continue

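A hedged sketch of the recovery step described in the comments above (token values are invented): under async scheduling, the tail of the propagated full token list gives back the output token ids needed to rebuild correct input_ids for a re-added request.

```python
# Illustrative only: recover output token ids from the propagated full list.
all_token_ids = {"req-x": [11, 12, 13, 14, 15]}   # prompt + outputs, made up
num_output_tokens = 2                             # outputs produced so far

resumed_token_ids = all_token_ids["req-x"]
output_token_ids = resumed_token_ids[-num_output_tokens:]
assert output_token_ids == [14, 15]
```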
2 changes: 1 addition & 1 deletion vllm/v1/worker/tpu_model_runner.py
@@ -483,7 +483,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
req_state = self.requests[req_id]
num_computed_tokens = req_data.num_computed_tokens[i]
new_block_ids = req_data.new_block_ids[i]
resumed_from_preemption = req_data.resumed_from_preemption[i]
resumed_from_preemption = req_id in req_data.resumed_req_ids

# Update the cached states.
req_state.num_computed_tokens = num_computed_tokens