@@ -71,6 +71,7 @@ def __init__(
7171 self .finished_req_ids_dict : dict [int , set [str ]] | None = (
7272 defaultdict (set ) if include_finished_set else None
7373 )
74+ self .prev_step_scheduled_req_ids : set [str ] = set ()
7475
7576 # Scheduling constraints.
7677 self .max_num_running_reqs = self .scheduler_config .max_num_seqs
@@ -444,14 +445,9 @@ def schedule(self) -> SchedulerOutput:
444445 # `request.num_prompt_tokens` to consider the resumed
445446 # requests, which have output tokens.
446447 num_new_tokens = request .num_tokens - num_computed_tokens
447- if (
448- 0
449- < self .scheduler_config .long_prefill_token_threshold
450- < num_new_tokens
451- ):
452- num_new_tokens = (
453- self .scheduler_config .long_prefill_token_threshold
454- )
448+ threshold = self .scheduler_config .long_prefill_token_threshold
449+ if 0 < threshold < num_new_tokens :
450+ num_new_tokens = threshold
455451
456452 # chunked prefill has to be enabled explicitly to allow
457453 # pooling requests to be chunked
@@ -620,6 +616,11 @@ def schedule(self) -> SchedulerOutput:
620616 structured_output_request_ids , grammar_bitmask = self .get_grammar_bitmask (
621617 num_scheduled_tokens .keys (), scheduled_spec_decode_tokens
622618 )
619+
620+ # Record the request ids that were scheduled in this step.
621+ self .prev_step_scheduled_req_ids .clear ()
622+ self .prev_step_scheduled_req_ids .update (num_scheduled_tokens .keys ())
623+
623624 scheduler_output = SchedulerOutput (
624625 scheduled_new_reqs = new_reqs_data ,
625626 scheduled_cached_reqs = cached_reqs_data ,
@@ -691,14 +692,12 @@ def _make_cached_request_data(
691692 req_ids : list [str ] = []
692693 new_token_ids : list [list [int ]] = []
693694 new_block_ids : list [tuple [list [int ], ...] | None ] = []
694- resumed_req_token_ids : list [ list [int ] | None ] = []
695+ all_token_ids : dict [ str , list [int ]] = {}
695696 num_computed_tokens : list [int ] = []
696697 num_output_tokens : list [int ] = []
698+ resumed_req_ids = set ()
697699
698- # Because resumed_reqs is usually empty, it is more efficient to do
699- # in-place appending so that we don't need to allocate a new list.
700- resumed_from_preemption = [False ] * len (running_reqs )
701- resumed_from_preemption += [True ] * len (resumed_reqs )
700+ num_running_reqs = len (running_reqs )
702701 for idx , req in enumerate (itertools .chain (running_reqs , resumed_reqs )):
703702 req_id = req .request_id
704703 req_ids .append (req_id )
@@ -715,12 +714,14 @@ def _make_cached_request_data(
715714 req .num_computed_tokens : req .num_computed_tokens + num_tokens
716715 ]
717716 new_token_ids .append (token_ids )
718- resumed_token_ids = None
719- if resumed_from_preemption [idx ]:
720- resumed_token_ids = req .all_token_ids [
717+ scheduled_in_prev_step = req_id in self .prev_step_scheduled_req_ids
718+ if idx >= num_running_reqs :
719+ assert not scheduled_in_prev_step
720+ resumed_req_ids .add (req_id )
721+ if not scheduled_in_prev_step :
722+ all_token_ids [req_id ] = req .all_token_ids [
721723 : req .num_computed_tokens + num_tokens
722724 ]
723- resumed_req_token_ids .append (resumed_token_ids )
724725 new_block_ids .append (
725726 req_to_new_blocks [req_id ].get_block_ids (allow_none = True )
726727 )
@@ -731,9 +732,9 @@ def _make_cached_request_data(
731732
732733 return CachedRequestData (
733734 req_ids = req_ids ,
734- resumed_from_preemption = resumed_from_preemption ,
735+ resumed_req_ids = resumed_req_ids ,
735736 new_token_ids = new_token_ids ,
736- resumed_req_token_ids = resumed_req_token_ids ,
737+ all_token_ids = all_token_ids ,
737738 new_block_ids = new_block_ids ,
738739 num_computed_tokens = num_computed_tokens ,
739740 num_output_tokens = num_output_tokens ,
0 commit comments