[BugFix] Handle unscheduled requests properly when async scheduling #27756
Conversation
Signed-off-by: Nick Hill <nhill@redhat.com>
```python
threshold = self.scheduler_config.long_prefill_token_threshold
if 0 < threshold < num_new_tokens:
    num_new_tokens = threshold
```
Unrelated simplification; it hurt me to look at that formatting :)
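For readers unfamiliar with the idiom, here is a minimal standalone sketch of the chained-comparison form used above; the function wrapper and its name are ours, purely for illustration:

```python
def cap_new_tokens(num_new_tokens: int, threshold: int) -> int:
    """Cap newly scheduled tokens at a positive threshold; 0 disables the cap."""
    # Chained comparison: true only when threshold > 0 AND threshold < num_new_tokens,
    # i.e. equivalent to `threshold > 0 and num_new_tokens > threshold`.
    if 0 < threshold < num_new_tokens:
        num_new_tokens = threshold
    return num_new_tokens


assert cap_new_tokens(100, 0) == 100   # threshold disabled
assert cap_new_tokens(100, 32) == 32   # capped
assert cap_new_tokens(16, 32) == 16    # already under the threshold
```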
```diff
 # the request's block IDs. For those in the set, new_block_ids will be used as the
 # request's block IDs instead of appending to the existing block IDs.
-resumed_from_preemption: list[bool]
+resumed_req_ids: set[str]
```
Changing this to a set, since these will be rare and we are currently creating a `[None] * batch_size` list every time.
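A rough sketch of the tradeoff; all names other than `resumed_req_ids` are illustrative, not vLLM code:

```python
batch_size = 512
req_ids = [f"req-{i}" for i in range(batch_size)]

# Old shape: a dense list rebuilt every step, even though resumption is rare.
resumed_from_preemption = [False] * batch_size  # O(batch_size) allocation per step

# New shape: a set that is empty on almost every step; membership checks are
# O(1) and no per-step allocation scales with the batch size.
resumed_req_ids: set[str] = {"req-7"}  # the rare resumed request

for req_id in req_ids:
    if req_id in resumed_req_ids:
        pass  # replace the request's block IDs instead of appending to them
```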
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Signed-off-by: Nick Hill <nhill@redhat.com>
Signed-off-by: Nick Hill <nhill@redhat.com>
This pull request has merge conflicts that must be resolved before it can be merged.
# Conflicts:
#	vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
```python
@cached_property
@deprecated("use resumed_req_ids field")
def resumed_from_preemption(self) -> list[bool]:
    return [req_id in self.resumed_req_ids for req_id in self.req_ids]

@cached_property
@deprecated("use all_token_ids field")
def resumed_req_token_ids(self) -> list[list[int] | None]:
    return [
        self.all_token_ids[req_id] if req_id in self.resumed_req_ids else None
        for req_id in self.req_ids
    ]
```
These are for backwards compatibility.
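A self-contained sketch of how these shims behave for legacy callers; the class and field layout below are illustrative stand-ins, not the actual vLLM `CachedRequestData` definition, and we assume `deprecated` comes from `typing_extensions` (PEP 702):

```python
from dataclasses import dataclass
from functools import cached_property

from typing_extensions import deprecated


@dataclass
class CachedRequestDataSketch:
    req_ids: list[str]
    resumed_req_ids: set[str]
    all_token_ids: dict[str, list[int]]

    @cached_property
    @deprecated("use resumed_req_ids field")
    def resumed_from_preemption(self) -> list[bool]:
        # Rebuild the old dense-list view on demand for legacy callers.
        return [req_id in self.resumed_req_ids for req_id in self.req_ids]


data = CachedRequestDataSketch(
    req_ids=["a", "b"],
    resumed_req_ids={"b"},
    all_token_ids={"b": [1, 2, 3]},
)
# Accessing the property emits a DeprecationWarning and returns the legacy view.
assert data.resumed_from_preemption == [False, True]
```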
LGTM, discussed online. I would like to see a test case for the new coverage; we can go ahead and merge at any point.
For follow-on:
…llm-project#27756)
Signed-off-by: Nick Hill <nhill@redhat.com>
…equests properly when async scheduling #27756 (#507)
Culprit commit: vllm-project/vllm#27756
---------
Signed-off-by: Agata Dobrzyniewicz <adobrzyniewicz@habana.ai>
Signed-off-by: Michał Kuligowski <michal.kuligowski@intel.com>
Signed-off-by: Agata Dobrzyniewicz <160237065+adobrzyn@users.noreply.github.com>
Co-authored-by: Michał Kuligowski <michal.kuligowski@intel.com>
There may be circumstances other than preemption where a running request is temporarily not scheduled in the batch for some step(s). These need to be handled similarly to the async scheduling + preemption fix made in #26385.

This PR also streamlines how resumed requests are recorded in `CachedRequestData`.
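To make the failure mode concrete, here is a purely illustrative sketch (none of these names come from vLLM) of why a request that skips a step must be treated like a preemption-resume when it reappears:

```python
# Scheduled request IDs across three consecutive steps.
step1 = {"a", "b", "c"}
step2 = {"a", "c"}       # "b" is still running but temporarily not scheduled
step3 = {"a", "b", "c"}  # "b" is scheduled again


def reentering(prev_step: set[str], cur_step: set[str], running: set[str]) -> set[str]:
    """Running requests absent last step but scheduled this step.

    With async scheduling, the worker's cached state for these requests may be
    stale, so their block IDs must be replaced wholesale (as after a
    preemption-resume) rather than appended to.
    """
    return (cur_step - prev_step) & running


running = {"a", "b", "c"}
assert reentering(step2, step3, running) == {"b"}
```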