From 136a515f379df15b808d9989759516e5ecbaf8ef Mon Sep 17 00:00:00 2001 From: bluecoffee8 Date: Thu, 27 Nov 2025 00:02:33 -0800 Subject: [PATCH 1/5] pp pd decode wip --- python/sglang/srt/disaggregation/decode.py | 51 ++- .../sglang/srt/managers/scheduler_pp_mixin.py | 337 +++++++++++++++++- 2 files changed, 373 insertions(+), 15 deletions(-) diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py index 038a6dc21ec..3b37c427d6a 100644 --- a/python/sglang/srt/disaggregation/decode.py +++ b/python/sglang/srt/disaggregation/decode.py @@ -25,7 +25,7 @@ from collections import deque from dataclasses import dataclass from http import HTTPStatus -from typing import TYPE_CHECKING, List, Optional, Type, Union +from typing import TYPE_CHECKING, List, Optional, Tuple, Type, Union import torch from torch.distributed import ProcessGroup @@ -195,6 +195,7 @@ def __init__( bootstrap_port: int, max_total_num_tokens: int, prefill_pp_size: int, + pp_rank: int, num_reserved_decode_tokens: int, transfer_backend: TransferBackend, ): @@ -216,6 +217,7 @@ def __init__( self.bootstrap_port = bootstrap_port self.max_total_num_tokens = max_total_num_tokens self.prefill_pp_size = prefill_pp_size + self.pp_rank = pp_rank self.num_reserved_decode_tokens = num_reserved_decode_tokens self.transfer_backend = transfer_backend # Queue for requests pending pre-allocation @@ -233,7 +235,7 @@ def _init_kv_manager(self) -> BaseKVManager: kv_args.decode_tp_size = attn_tp_size # Note(shangming): pp is not supported on the decode side yet, so its rank is fixed to 0 - kv_args.pp_rank = 0 + kv_args.pp_rank = self.pp_rank kv_args.system_dp_rank = self.scheduler.dp_rank kv_args.prefill_pp_size = self.prefill_pp_size kv_data_ptrs, kv_data_lens, kv_item_lens = ( @@ -336,7 +338,9 @@ def extend(self, reqs: List[Req], is_retracted: bool = False) -> None: for req in reqs: self.add(req, is_retracted=is_retracted) - def resume_retracted_reqs(self) -> List[Req]: + def resume_retracted_reqs( + self, rids_to_check: Optional[List[str]] = None + ) -> Tuple[List[Req], bool]: # TODO refactor the scheduling part, reuse with the unified engine logic as much as possible # allocate memory @@ -345,6 +349,9 @@ def resume_retracted_reqs(self) -> List[Req]: allocatable_tokens = self._allocatable_tokens(count_retracted=False) for i, req in enumerate(self.retracted_queue): + if rids_to_check is not None and req.rid not in rids_to_check: + continue + if self.req_to_token_pool.available_size() <= 0: break @@ -371,9 +378,17 @@ def resume_retracted_reqs(self) -> List[Req]: if i not in indices_to_remove ] - return resumed_reqs + has_retracted_req = ( + bool(len(self.retracted_queue) > 0) + if rids_to_check is None + else any(req.rid in rids_to_check for req in self.retracted_queue) + ) - def _update_handshake_waiters(self) -> None: + return resumed_reqs, has_retracted_req + + def _update_handshake_waiters( + self, rids_to_check: Optional[List[str]] = None + ) -> None: if not self.queue: return @@ -385,6 +400,9 @@ def _update_handshake_waiters(self) -> None: ) for i, (decode_req, poll) in enumerate(zip(self.queue, polls)): + if rids_to_check is not None and decode_req.rid not in rids_to_check: + continue + if poll == KVPoll.Bootstrapping: pass elif poll == KVPoll.WaitingForInput: @@ -406,10 +424,13 @@ def _update_handshake_waiters(self) -> None: else: raise ValueError(f"Unexpected poll case: {poll}") - def pop_preallocated(self) -> List[DecodeRequest]: + def pop_preallocated( + self, rids_to_check: Optional[List[str]] = 
None + ) -> List[DecodeRequest]: """Pop the preallocated requests from the pending queue (FIFO).""" - self._update_handshake_waiters() + self._update_handshake_waiters(rids_to_check) + failed_reqs = [] preallocated_reqs = [] indices_to_remove = set() @@ -424,14 +445,20 @@ def pop_preallocated(self) -> List[DecodeRequest]: ) # First, remove all failed requests from the queue for i, decode_req in enumerate(self.queue): + if rids_to_check is not None and decode_req.rid not in rids_to_check: + continue if isinstance(decode_req.req.finished_reason, FINISH_ABORT): self.scheduler.stream_output( [decode_req.req], decode_req.req.return_logprob ) + failed_reqs.append(decode_req) indices_to_remove.add(i) # Then, preallocate the remaining requests if possible for i, decode_req in enumerate(self.queue): + if rids_to_check is not None and decode_req.rid not in rids_to_check: + continue + if i in indices_to_remove: continue @@ -540,7 +567,7 @@ def pop_preallocated(self) -> List[DecodeRequest]: entry for i, entry in enumerate(self.queue) if i not in indices_to_remove ] - return preallocated_reqs + return preallocated_reqs, failed_reqs @property def num_tokens_pre_allocated(self): @@ -722,7 +749,7 @@ def _commit_transfer_to_req(self, decode_req: DecodeRequest) -> None: ) decode_req.req.time_stats.wait_queue_entry_time = time.perf_counter() - def pop_transferred(self) -> List[Req]: + def pop_transferred(self, rids_to_check: Optional[List[str]] = None) -> List[Req]: if not self.queue: return [] polls = poll_and_all_reduce( @@ -732,6 +759,8 @@ def pop_transferred(self) -> List[Req]: transferred_reqs = [] indices_to_remove = set() for i, (decode_req, poll) in enumerate(zip(self.queue, polls)): + if rids_to_check is not None and decode_req.rid not in rids_to_check: + continue if poll == KVPoll.Failed: error_message = f"Decode transfer failed for request rank={self.tp_rank} {decode_req.req.rid=} {decode_req.req.bootstrap_room=}" try: @@ -939,7 +968,7 @@ def process_decode_queue(self: Scheduler): self.decode_offload_manager.check_offload_progress() # try to resume retracted requests if there are enough space for another `num_reserved_decode_tokens` decode steps - resumed_reqs = self.disagg_decode_prealloc_queue.resume_retracted_reqs() + resumed_reqs, _ = self.disagg_decode_prealloc_queue.resume_retracted_reqs() self.waiting_queue.extend(resumed_reqs) if len(self.disagg_decode_prealloc_queue.retracted_queue) > 0: # if there are still retracted requests, we do not allocate new requests @@ -954,7 +983,7 @@ def process_decode_queue(self: Scheduler): self.polling_count = (self.polling_count + 1) % self.polling_interval if self.polling_count % self.polling_interval == 0: - req_conns = self.disagg_decode_prealloc_queue.pop_preallocated() + req_conns, _ = self.disagg_decode_prealloc_queue.pop_preallocated() self.disagg_decode_transfer_queue.extend(req_conns) alloc_reqs = ( self.disagg_decode_transfer_queue.pop_transferred() diff --git a/python/sglang/srt/managers/scheduler_pp_mixin.py b/python/sglang/srt/managers/scheduler_pp_mixin.py index bef6c5972bd..8049125119b 100644 --- a/python/sglang/srt/managers/scheduler_pp_mixin.py +++ b/python/sglang/srt/managers/scheduler_pp_mixin.py @@ -579,12 +579,14 @@ def _pp_launch_batch( ) return result, event - def get_rids(self: Scheduler, req_queue: List[Req], *poll_statuses_group): + def get_rids( + self: Scheduler, req_queue: List[Req], sender: bool, *poll_statuses_group + ): """ Used by PP, get the required rids with the given poll statuses. 
""" polls = poll_and_all_reduce( - [req.disagg_kv_sender for req in req_queue], + [req.disagg_kv_sender if sender else req.kv_receiver for req in req_queue], self.tp_worker.get_attention_tp_cpu_group(), ) rids: List = [] @@ -751,6 +753,7 @@ def _pp_pd_get_bootstrapped_ids(self: Scheduler): # First rank, pop the bootstrap reqs from the bootstrap queue good_bootstrapped_rids, bad_bootstrapped_rids = self.get_rids( self.disagg_prefill_bootstrap_queue.queue, + True, [KVPoll.WaitingForInput], [KVPoll.Failed], ) @@ -762,6 +765,7 @@ def _pp_pd_get_bootstrapped_ids(self: Scheduler): ) curr_good_bootstrapped_rids, curr_bad_bootstrapped_rids = self.get_rids( self.disagg_prefill_bootstrap_queue.queue, + True, [KVPoll.WaitingForInput], [KVPoll.Failed], ) @@ -773,11 +777,12 @@ def _pp_pd_get_bootstrapped_ids(self: Scheduler): ) return [good_bootstrapped_rids, bad_bootstrapped_rids] - def _pp_pd_get_transferred_ids(self: Scheduler): + def _pp_pd_get_prefill_transferred_ids(self: Scheduler): # get the current stage transfer success if self.pp_group.is_first_rank: transferred_rids = self.get_rids( self.disagg_prefill_inflight_queue, + True, [KVPoll.Success, KVPoll.Failed], ) # if other ranks, do intersection with the previous rank's transferred rids @@ -788,6 +793,7 @@ def _pp_pd_get_transferred_ids(self: Scheduler): # 2. get the current stage's transferred reqs info curr_transferred_rids = self.get_rids( self.disagg_prefill_inflight_queue, + True, [KVPoll.Success, KVPoll.Failed], ) # 3. new consensus rids = intersection(previous consensus rids, transfer finished rids) @@ -929,7 +935,7 @@ def event_loop_pp_disagg_prefill(self: Scheduler): bmbs[mb_id] = bootstrapped_rids self._pp_commit_comm_work(send_bootstrapped_work) - transferred_rids = self._pp_pd_get_transferred_ids() + transferred_rids = self._pp_pd_get_prefill_transferred_ids() self._pp_commit_comm_work(send_transfer_work) tmbs[mb_id] = transferred_rids @@ -1043,3 +1049,326 @@ def event_loop_pp_disagg_prefill(self: Scheduler): self.check_tree_cache() self.new_token_ratio = self.init_new_token_ratio self.maybe_sleep_on_idle() + + def _pp_pd_get_retract_ids(self: Scheduler): + # communicate pre-consensus retracted reqs + curr_retract_rids = [ + req.rid for req in self.disagg_decode_prealloc_queue.retracted_queue + ] + if self.pp_group.is_first_rank: + # First rank, get all retracted req ids + return curr_retract_rids + else: + # Other ranks, receive the retracted reqs info from the previous rank and ensure the consensus + prev_retract_rids = self._pp_recv_pyobj_from_prev_stage() + return list(set(prev_retract_rids) & set(curr_retract_rids)) + + def _pp_pd_get_prealloc_ids(self: Scheduler): + # communicate pre-consensus prealloc reqs + if self.pp_group.is_first_rank: + # First rank, pop the preallocated reqs from the prealloc queue + good_prealloc_rids, bad_prealloc_rids = self.get_rids( + self.disagg_decode_prealloc_queue.queue, + False, + [KVPoll.WaitingForInput], + [KVPoll.Failed], + ) + else: + # Other ranks, receive the preallocated reqs info from the previous rank and ensure the consensus + prev_prealloc_rids = self._pp_recv_pyobj_from_prev_stage() + prev_good_prealloc_rids, prev_bad_prealloc_rids = prev_prealloc_rids + curr_good_prealloc_rids, curr_bad_prealloc_rids = self.get_rids( + self.disagg_decode_prealloc_queue.queue, + False, + [KVPoll.WaitingForInput], + [KVPoll.Failed], + ) + good_prealloc_rids = list( + set(prev_good_prealloc_rids) & set(curr_good_prealloc_rids) + ) + bad_prealloc_rids = list( + set(prev_bad_prealloc_rids) | 
set(curr_bad_prealloc_rids) + ) + return [good_prealloc_rids, bad_prealloc_rids] + + def _pp_pd_get_decode_transferred_ids(self: Scheduler): + # get the current stage transfer success + if self.pp_group.is_first_rank: + transferred_rids = self.get_rids( + self.disagg_decode_transfer_queue, + False, + [KVPoll.Success, KVPoll.Failed], + ) + # if other ranks, do intersection with the previous rank's transferred rids + else: + # 2 (Release): Receive the transferred rids from the previous rank + # 1. recv previous stage's transferred reqs info + prev_transferred_rids = self._pp_recv_pyobj_from_prev_stage() + # 2. get the current stage's transferred reqs info + curr_transferred_rids = self.get_rids( + self.disagg_decode_transfer_queue, + False, + [KVPoll.Success, KVPoll.Failed], + ) + # 3. new consensus rids = intersection(previous consensus rids, transfer finished rids) + transferred_rids = list( + set(prev_transferred_rids) & set(curr_transferred_rids) + ) + return transferred_rids + + # from process_decode_queue + def process_retract_queue(self: Scheduler, retract_rids: Optional[List[str]]): + # try to resume retracted requests if there are enough space for another `num_reserved_decode_tokens` decode steps + resumed_reqs, has_retracted_req = ( + self.disagg_decode_prealloc_queue.resume_retracted_reqs(retract_rids) + ) + self.waiting_queue.extend(resumed_reqs) + return [req.rid for req in resumed_reqs], has_retracted_req + + # from process_decode_queue + def process_prealloc_queue( + self: Scheduler, prealloc_rids: Optional[List[str]], has_retracted_req: bool + ): + if has_retracted_req: + # if there are still retracted requests, we do not allocate new requests + return None + + # TODO: figure out if we need polling_count, probably do not since we do not need to poll + # in particular we do not call _update_handshake_waiters + + if prealloc_rids is not None: + ( + good_consensus_prealloc_rids, + bad_consensus_prealloc_rids, + ) = prealloc_rids + good_reqs, failed_reqs = self.disagg_decode_prealloc_queue.pop_preallocated( + rids_to_check=good_consensus_prealloc_rids + + bad_consensus_prealloc_rids, + ) + self.disagg_decode_transfer_queue.extend(good_reqs) + return [[req.rid for req in good_reqs], [req.rid for req in failed_reqs]] + return None + + def process_decode_transfer_queue( + self: Scheduler, release_rids: Optional[List[str]] + ): + if release_rids is not None: + released_reqs = self.disagg_decode_transfer_queue.pop_transferred( + release_rids + ) + self.waiting_queue.extend(released_reqs) + return [req.rid for req in released_reqs] + return None + + @DynamicGradMode() + def event_loop_pp_disagg_decode(self: Scheduler): + self.pp_loop_size: int = self.pp_size + self.server_args.pp_async_batch_depth + mbs = [None] * self.pp_loop_size + last_mbs = [None] * self.pp_loop_size + self.running_mbs = [ + ScheduleBatch(reqs=[], batch_is_full=False) + for _ in range(self.pp_loop_size) + ] + mb_metadata: List[Optional[PPBatchMetadata]] = [None] * self.pp_loop_size + pp_outputs: Optional[PPProxyTensors] = None + last_rank_comm_queue: deque[Tuple[torch.cuda.Event, PPProxyTensors]] = deque() + + # PD additional + + # consensus rids + consensus_retract_rids: Optional[List[str]] = None + consensus_prealloc_rids: Optional[List[str]] = None + release_rids: Optional[List[str]] = None + + rmbs = [None] * self.pp_loop_size + bmbs = [None] * self.pp_loop_size + tmbs = [None] * self.pp_loop_size + + send_req_work = [] + + # send info to reach consensus + send_retract_work = [] + send_prealloc_work = [] + 
send_transfer_work = [] + + # send consensus info + send_consensus_retract_work = [] + send_consensus_prealloc_work = [] + send_release_work = [] + + send_proxy_work = [] + send_output_work = [] + + while True: + server_is_idle = True + for mb_id in range(self.pp_loop_size): + self.running_batch = self.running_mbs[mb_id] + self.last_batch = last_mbs[mb_id] + next_first_rank_mb_id = (mb_id + self.pp_size) % self.pp_loop_size + next_mb_id = (mb_id + 1) % self.pp_loop_size + + next_pp_outputs = None + next_release_rids = None + next_consensus_prealloc_rids = None + next_consensus_retract_rids = None + d2h_event = None + next_batch_result = None + + recv_reqs = self.recv_requests() + self.process_input_requests(recv_reqs) + + if not self.pp_group.is_last_rank: + self._pp_commit_comm_work(send_req_work) + + # reaching consensus through PP ranks + retract_rids = self._pp_pd_get_retract_ids() + rmbs[mb_id] = retract_rids + self._pp_commit_comm_work(send_retract_work) + + prealloc_rids = self._pp_pd_get_prealloc_ids() + bmbs[mb_id] = prealloc_rids + self._pp_commit_comm_work(send_prealloc_work) + + transferred_rids = self._pp_pd_get_decode_transferred_ids() + tmbs[mb_id] = transferred_rids + self._pp_commit_comm_work(send_transfer_work) + + # get batch to run and proxy tensors if needed + batch = self.get_next_disagg_decode_batch_to_run() + mbs[mb_id] = batch + self.running_mbs[mb_id] = self.running_batch + + self.cur_batch: Optional[ScheduleBatch] = mbs[mb_id] + if self.cur_batch: + server_is_idle = False + pp_proxy_tensors = self._pp_recv_proxy_tensors() + + # early send output if possible + if self.server_args.pp_async_batch_depth > 0: + self._pp_commit_comm_work(work=send_output_work) + next_pp_outputs, next_batch_result, d2h_event, send_output_work = ( + self._pp_send_recv_and_preprocess_output_tensors( + next_first_rank_mb_id, + next_mb_id, + mbs, + mb_metadata, + last_rank_comm_queue, + pp_outputs, + ) + ) + self._pp_commit_comm_work(send_proxy_work) + # run batch + if self.cur_batch: + result, event = self._pp_launch_batch( + mb_id, pp_proxy_tensors, mb_metadata, last_rank_comm_queue + ) + # regular send output + if self.server_args.pp_async_batch_depth == 0: + self._pp_commit_comm_work(work=send_output_work) + next_pp_outputs, next_batch_result, d2h_event, send_output_work = ( + self._pp_send_recv_and_preprocess_output_tensors( + next_first_rank_mb_id, + next_mb_id, + mbs, + mb_metadata, + last_rank_comm_queue, + pp_outputs, + ) + ) + + # reach consensus on last rank and send to PP=0 + # otherwise, just pass along previous consensus + send_consensus_retract_work, consensus_retract_rids = ( + self._pp_pd_send_consensus_bootstrapped_ids( # reuse the function + rmbs, + next_first_rank_mb_id, + consensus_retract_rids, + retract_rids, + ) + ) + + send_consensus_prealloc_work, consensus_prealloc_rids = ( + self._pp_pd_send_consensus_bootstrapped_ids( # reuse the function + bmbs, + next_first_rank_mb_id, + consensus_prealloc_rids, + prealloc_rids, + ) + ) + + send_release_work, release_rids = ( + self._pp_pd_send_consensus_release_ids( + tmbs, next_first_rank_mb_id, release_rids, transferred_rids + ) + ) + + # from process_decode_queue + if self.server_args.disaggregation_decode_enable_offload_kvcache: + self.decode_offload_manager.check_offload_progress() + + has_retracted_req = False + if rmbs[next_mb_id] is not None: + next_consensus_retract_rids = self._pp_recv_pyobj_from_prev_stage() + next_consensus_retract_rids, has_retracted_req = ( + self.process_retract_queue( # TODO: implement this + 
next_consensus_retract_rids + ) + ) + self._pp_commit_comm_work(send_consensus_retract_work) + + if bmbs[next_mb_id] is not None: + next_consensus_prealloc_rids = self._pp_recv_pyobj_from_prev_stage() + next_consensus_prealloc_rids = self.process_prealloc_queue( + next_consensus_prealloc_rids, has_retracted_req + ) + self._pp_commit_comm_work(send_consensus_prealloc_work) + + if tmbs[next_mb_id] is not None: + next_release_rids = self._pp_recv_pyobj_from_prev_stage() + next_release_rids = self.process_decode_transfer_queue( + next_release_rids + ) + self._pp_commit_comm_work(send_release_work) + + # post-process the coming microbatch + if mbs[next_mb_id] is not None: + d2h_event.synchronize() + self._pp_process_batch_result( + mbs[next_mb_id], + next_batch_result, + ) + last_mbs[next_mb_id] = mbs[next_mb_id] + + if not self.pp_group.is_last_rank: + send_req_work = self._pp_send_pyobj_to_next_stage( + recv_reqs, async_send=True + ) + send_prealloc_work = self._pp_send_pyobj_to_next_stage( + prealloc_rids, async_send=True + ) + send_transfer_work = self._pp_send_pyobj_to_next_stage( + transferred_rids, async_send=True + ) + if self.cur_batch: + torch.cuda.current_stream().wait_event(event) + send_proxy_work = self._pp_send_dict_to_next_stage( + result.pp_hidden_states_proxy_tensors.tensors, + async_send=True, + ) + + if hasattr(self, "delayed_weight_sync_fn"): + self.delayed_weight_sync_fn() + self.delayed_weight_sync_fn = None + + pp_outputs = next_pp_outputs + release_rids = next_release_rids + consensus_prealloc_rids = next_consensus_prealloc_rids + + self.running_batch.batch_is_full = False + + # When the server is idle, self-check and re-init some states + if server_is_idle and len(self.disagg_decode_transfer_queue) == 0: + self.check_memory() + self.check_tree_cache() + self.new_token_ratio = self.init_new_token_ratio + self.maybe_sleep_on_idle() From 5bb4a14891f96fe7804bbb6ac9b59b405997f378 Mon Sep 17 00:00:00 2001 From: bluecoffee8 Date: Fri, 28 Nov 2025 20:30:26 -0800 Subject: [PATCH 2/5] pp x pd decode startup --- python/sglang/srt/disaggregation/decode.py | 8 +- python/sglang/srt/entrypoints/engine.py | 4 +- python/sglang/srt/managers/scheduler.py | 6 +- .../sglang/srt/managers/scheduler_pp_mixin.py | 102 +++++++++++------- 4 files changed, 74 insertions(+), 46 deletions(-) diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py index 3b37c427d6a..0ccbea0e836 100644 --- a/python/sglang/srt/disaggregation/decode.py +++ b/python/sglang/srt/disaggregation/decode.py @@ -400,7 +400,7 @@ def _update_handshake_waiters( ) for i, (decode_req, poll) in enumerate(zip(self.queue, polls)): - if rids_to_check is not None and decode_req.rid not in rids_to_check: + if rids_to_check is not None and decode_req.req.rid not in rids_to_check: continue if poll == KVPoll.Bootstrapping: @@ -445,7 +445,7 @@ def pop_preallocated( ) # First, remove all failed requests from the queue for i, decode_req in enumerate(self.queue): - if rids_to_check is not None and decode_req.rid not in rids_to_check: + if rids_to_check is not None and decode_req.req.rid not in rids_to_check: continue if isinstance(decode_req.req.finished_reason, FINISH_ABORT): self.scheduler.stream_output( @@ -456,7 +456,7 @@ def pop_preallocated( # Then, preallocate the remaining requests if possible for i, decode_req in enumerate(self.queue): - if rids_to_check is not None and decode_req.rid not in rids_to_check: + if rids_to_check is not None and decode_req.req.rid not in rids_to_check: continue 
if i in indices_to_remove: @@ -759,7 +759,7 @@ def pop_transferred(self, rids_to_check: Optional[List[str]] = None) -> List[Req transferred_reqs = [] indices_to_remove = set() for i, (decode_req, poll) in enumerate(zip(self.queue, polls)): - if rids_to_check is not None and decode_req.rid not in rids_to_check: + if rids_to_check is not None and decode_req.req.rid not in rids_to_check: continue if poll == KVPoll.Failed: error_message = f"Decode transfer failed for request rank={self.tp_rank} {decode_req.req.rid=} {decode_req.req.bootstrap_room=}" diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 19c4e29a769..eb7024d20f0 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -727,7 +727,7 @@ def _set_envs_and_config(server_args: ServerArgs): if server_args.attention_backend == "flashinfer": assert_pkg_version( "flashinfer_python", - "0.5.3", + "0.5.2", "Please uninstall the old version and " "reinstall the latest version by following the instructions " "at https://docs.flashinfer.ai/installation.html.", @@ -735,7 +735,7 @@ def _set_envs_and_config(server_args: ServerArgs): if _is_cuda and not get_bool_env_var("SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK"): assert_pkg_version( "sgl-kernel", - "0.3.17.post2", + "0.3.17.post1", "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`", ) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 3421cdeaf64..3aa813ee9ef 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -877,6 +877,7 @@ def init_disaggregation(self): bootstrap_port=self.server_args.disaggregation_bootstrap_port, max_total_num_tokens=self.max_total_num_tokens, prefill_pp_size=self.server_args.disaggregation_prefill_pp, + pp_rank=self.pp_rank, num_reserved_decode_tokens=self.server_args.num_reserved_decode_tokens, transfer_backend=self.transfer_backend, ) @@ -2698,7 +2699,10 @@ def run_scheduler_process( if scheduler.enable_overlap: scheduler.event_loop_overlap_disagg_decode() else: - scheduler.event_loop_normal_disagg_decode() + if server_args.pp_size > 1: + scheduler.event_loop_pp_disagg_decode() + else: + scheduler.event_loop_normal_disagg_decode() except Exception: traceback = get_exception_traceback() diff --git a/python/sglang/srt/managers/scheduler_pp_mixin.py b/python/sglang/srt/managers/scheduler_pp_mixin.py index 8049125119b..6332ac754da 100644 --- a/python/sglang/srt/managers/scheduler_pp_mixin.py +++ b/python/sglang/srt/managers/scheduler_pp_mixin.py @@ -503,12 +503,13 @@ def _pp_send_output_to_next_stage( # send ready PP output to rank 0 if mbs[next_first_rank_mb_id] is not None: q_event, pp_outputs_to_send = last_rank_comm_queue.popleft() - torch.cuda.current_stream().wait_event(q_event) - with torch.profiler.record_function("send_res_dict_to_next_stage"): - send_output_work = self._pp_send_dict_to_next_stage( - pp_outputs_to_send.tensors, - async_send=True, - ) + if not mbs[next_first_rank_mb_id].forward_mode.is_prebuilt(): + torch.cuda.current_stream().wait_event(q_event) + with torch.profiler.record_function("send_res_dict_to_next_stage"): + send_output_work = self._pp_send_dict_to_next_stage( + pp_outputs_to_send.tensors, + async_send=True, + ) # send the outputs from the last round to let the next stage worker run post processing if not self.pp_group.is_last_rank: if pp_outputs: @@ -540,14 +541,19 @@ def _pp_send_recv_and_preprocess_output_tensors( if 
mbs[next_mb_id] is not None: with torch.profiler.record_function("recv_res_dict_from_prev_stage"): - next_pp_outputs = PPProxyTensors(self._pp_recv_dict_from_prev_stage()) - with self.copy_stream_ctx: - self.copy_stream.wait_stream(self.default_stream) - batch_result = self._pp_prep_batch_result( - mbs[next_mb_id], mb_metadata[next_mb_id], next_pp_outputs - ) - d2h_event = torch.cuda.Event() - d2h_event.record(torch.cuda.current_stream()) + next_pp_outputs = None + if not mbs[next_mb_id].forward_mode.is_prebuilt(): + next_pp_outputs = PPProxyTensors( + self._pp_recv_dict_from_prev_stage() + ) + if not mbs[next_mb_id].forward_mode.is_prebuilt(): + with self.copy_stream_ctx: + self.copy_stream.wait_stream(self.default_stream) + batch_result = self._pp_prep_batch_result( + mbs[next_mb_id], mb_metadata[next_mb_id], next_pp_outputs + ) + d2h_event = torch.cuda.Event() + d2h_event.record(torch.cuda.current_stream()) return next_pp_outputs, batch_result, d2h_event, send_output_work @@ -580,20 +586,20 @@ def _pp_launch_batch( return result, event def get_rids( - self: Scheduler, req_queue: List[Req], sender: bool, *poll_statuses_group + self: Scheduler, req_queue: List[Req], is_send: bool, *poll_statuses_group ): """ Used by PP, get the required rids with the given poll statuses. """ polls = poll_and_all_reduce( - [req.disagg_kv_sender if sender else req.kv_receiver for req in req_queue], + [req.disagg_kv_sender if is_send else req.kv_receiver for req in req_queue], self.tp_worker.get_attention_tp_cpu_group(), ) rids: List = [] for poll_statuses in poll_statuses_group: rids.append( [ - req.rid + req.rid if is_send else req.req.rid for req, poll in zip(req_queue, polls) if poll in poll_statuses ] @@ -1095,7 +1101,7 @@ def _pp_pd_get_decode_transferred_ids(self: Scheduler): # get the current stage transfer success if self.pp_group.is_first_rank: transferred_rids = self.get_rids( - self.disagg_decode_transfer_queue, + self.disagg_decode_transfer_queue.queue, False, [KVPoll.Success, KVPoll.Failed], ) @@ -1106,7 +1112,7 @@ def _pp_pd_get_decode_transferred_ids(self: Scheduler): prev_transferred_rids = self._pp_recv_pyobj_from_prev_stage() # 2. 
get the current stage's transferred reqs info curr_transferred_rids = self.get_rids( - self.disagg_decode_transfer_queue, + self.disagg_decode_transfer_queue.queue, False, [KVPoll.Success, KVPoll.Failed], ) @@ -1118,12 +1124,14 @@ def _pp_pd_get_decode_transferred_ids(self: Scheduler): # from process_decode_queue def process_retract_queue(self: Scheduler, retract_rids: Optional[List[str]]): - # try to resume retracted requests if there are enough space for another `num_reserved_decode_tokens` decode steps - resumed_reqs, has_retracted_req = ( - self.disagg_decode_prealloc_queue.resume_retracted_reqs(retract_rids) - ) - self.waiting_queue.extend(resumed_reqs) - return [req.rid for req in resumed_reqs], has_retracted_req + if retract_rids is not None: + # try to resume retracted requests if there are enough space for another `num_reserved_decode_tokens` decode steps + resumed_reqs, has_retracted_req = ( + self.disagg_decode_prealloc_queue.resume_retracted_reqs(retract_rids) + ) + self.waiting_queue.extend(resumed_reqs) + return [req.rid for req in resumed_reqs], has_retracted_req + return None, False # from process_decode_queue def process_prealloc_queue( @@ -1146,7 +1154,10 @@ def process_prealloc_queue( + bad_consensus_prealloc_rids, ) self.disagg_decode_transfer_queue.extend(good_reqs) - return [[req.rid for req in good_reqs], [req.rid for req in failed_reqs]] + return [ + [req.req.rid for req in good_reqs], + [req.req.rid for req in failed_reqs], + ] return None def process_decode_transfer_queue( @@ -1208,9 +1219,9 @@ def event_loop_pp_disagg_decode(self: Scheduler): next_mb_id = (mb_id + 1) % self.pp_loop_size next_pp_outputs = None - next_release_rids = None - next_consensus_prealloc_rids = None next_consensus_retract_rids = None + next_consensus_prealloc_rids = None + next_release_rids = None d2h_event = None next_batch_result = None @@ -1241,7 +1252,9 @@ def event_loop_pp_disagg_decode(self: Scheduler): self.cur_batch: Optional[ScheduleBatch] = mbs[mb_id] if self.cur_batch: server_is_idle = False - pp_proxy_tensors = self._pp_recv_proxy_tensors() + pp_proxy_tensors = None + if not self.cur_batch.forward_mode.is_prebuilt(): + pp_proxy_tensors = self._pp_recv_proxy_tensors() # early send output if possible if self.server_args.pp_async_batch_depth > 0: @@ -1310,9 +1323,7 @@ def event_loop_pp_disagg_decode(self: Scheduler): if rmbs[next_mb_id] is not None: next_consensus_retract_rids = self._pp_recv_pyobj_from_prev_stage() next_consensus_retract_rids, has_retracted_req = ( - self.process_retract_queue( # TODO: implement this - next_consensus_retract_rids - ) + self.process_retract_queue(next_consensus_retract_rids) ) self._pp_commit_comm_work(send_consensus_retract_work) @@ -1332,24 +1343,28 @@ def event_loop_pp_disagg_decode(self: Scheduler): # post-process the coming microbatch if mbs[next_mb_id] is not None: - d2h_event.synchronize() - self._pp_process_batch_result( - mbs[next_mb_id], - next_batch_result, - ) + if not mbs[next_mb_id].forward_mode.is_prebuilt(): + d2h_event.synchronize() + self._pp_process_batch_result( + mbs[next_mb_id], + next_batch_result, + ) last_mbs[next_mb_id] = mbs[next_mb_id] if not self.pp_group.is_last_rank: send_req_work = self._pp_send_pyobj_to_next_stage( recv_reqs, async_send=True ) + send_retract_work = self._pp_send_pyobj_to_next_stage( + retract_rids, async_send=True + ) send_prealloc_work = self._pp_send_pyobj_to_next_stage( prealloc_rids, async_send=True ) send_transfer_work = self._pp_send_pyobj_to_next_stage( transferred_rids, async_send=True ) - 
if self.cur_batch: + if self.cur_batch and not self.cur_batch.forward_mode.is_prebuilt(): torch.cuda.current_stream().wait_event(event) send_proxy_work = self._pp_send_dict_to_next_stage( result.pp_hidden_states_proxy_tensors.tensors, @@ -1362,12 +1377,21 @@ def event_loop_pp_disagg_decode(self: Scheduler): pp_outputs = next_pp_outputs release_rids = next_release_rids + consensus_retract_rids = next_consensus_retract_rids consensus_prealloc_rids = next_consensus_prealloc_rids self.running_batch.batch_is_full = False # When the server is idle, self-check and re-init some states - if server_is_idle and len(self.disagg_decode_transfer_queue) == 0: + queue_size = ( + len(self.waiting_queue) + + len(self.disagg_decode_transfer_queue.queue) + + len(self.disagg_decode_prealloc_queue.queue) + ) + if self.server_args.disaggregation_decode_enable_offload_kvcache: + queue_size += len(self.decode_offload_manager.ongoing_offload) + + if server_is_idle and queue_size == 0: self.check_memory() self.check_tree_cache() self.new_token_ratio = self.init_new_token_ratio From d572676e456c1f024cfec45203aa66b0ecec5c04 Mon Sep 17 00:00:00 2001 From: bluecoffee8 Date: Tue, 2 Dec 2025 13:38:36 -0800 Subject: [PATCH 3/5] mooncake backend support --- .../sglang/srt/disaggregation/common/conn.py | 38 +++++++++++++------ .../srt/disaggregation/mooncake/conn.py | 4 +- 2 files changed, 29 insertions(+), 13 deletions(-) diff --git a/python/sglang/srt/disaggregation/common/conn.py b/python/sglang/srt/disaggregation/common/conn.py index 636ff72db21..7cafbd5a865 100644 --- a/python/sglang/srt/disaggregation/common/conn.py +++ b/python/sglang/srt/disaggregation/common/conn.py @@ -153,27 +153,34 @@ def _connect(self, endpoint: str, is_ipv6: bool = False): def get_mha_kv_ptrs_with_pp( self, src_kv_ptrs: List[int], dst_kv_ptrs: List[int] ) -> Tuple[List[int], List[int], List[int], List[int], int]: - # pp is not supported on the decode side yet start_layer = self.kv_args.prefill_start_layer num_kv_layers = len(src_kv_ptrs) // 2 end_layer = start_layer + num_kv_layers dst_num_total_layers = len(dst_kv_ptrs) // 2 src_k_ptrs = src_kv_ptrs[:num_kv_layers] src_v_ptrs = src_kv_ptrs[num_kv_layers:] - dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer] - dst_v_ptrs = dst_kv_ptrs[ - dst_num_total_layers + start_layer : dst_num_total_layers + end_layer - ] + if num_kv_layers == dst_num_total_layers: + dst_k_ptrs = dst_kv_ptrs[:dst_num_total_layers] + dst_v_ptrs = dst_kv_ptrs[dst_num_total_layers:] + else: + # Decode pp size should be equal to prefill pp size or 1 + dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer] + dst_v_ptrs = dst_kv_ptrs[ + dst_num_total_layers + start_layer : dst_num_total_layers + end_layer + ] layers_current_pp_stage = len(src_k_ptrs) return src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage def get_mla_kv_ptrs_with_pp( self, src_kv_ptrs: List[int], dst_kv_ptrs: List[int] ) -> Tuple[List[int], List[int], int]: - # pp is not supported on the decode side yet start_layer = self.kv_args.prefill_start_layer end_layer = start_layer + len(src_kv_ptrs) - sliced_dst_kv_ptrs = dst_kv_ptrs[start_layer:end_layer] + if len(src_kv_ptrs) == len(dst_kv_ptrs): + sliced_dst_kv_ptrs = dst_kv_ptrs + else: + # Decode pp size should be equal to prefill pp size or 1 + sliced_dst_kv_ptrs = dst_kv_ptrs[start_layer:end_layer] layers_current_pp_stage = len(src_kv_ptrs) return src_kv_ptrs, sliced_dst_kv_ptrs, layers_current_pp_stage @@ -273,8 +280,7 @@ def __init__( self.bootstrap_addr ] - # Currently, we don't allow prefill 
instance and decode instance to - # have different TP sizes per DP rank, except for models using MLA. + # Handling for PD with different TP sizes per DP rank if self.kv_mgr.attn_tp_size == self.prefill_attn_tp_size: self.target_tp_rank = ( self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size @@ -335,9 +341,19 @@ def __init__( else: self.prefill_dp_rank = bootstrap_room % self.prefill_dp_size - # FIXME: alias here: target_dp_group -> prefill_dp_rank self.target_dp_group = self.prefill_dp_rank + # Decode pp size should be equal to prefill pp size or 1 + assert ( + self.kv_mgr.pp_size == self.prefill_pp_size or self.kv_mgr.pp_size == 1 + ), ( + f"Decode pp size ({self.kv_mgr.pp_size}) should be equal to prefill pp size ({self.prefill_pp_size}) or 1", + ) + if self.prefill_pp_size == self.kv_mgr.pp_size: + self.target_pp_ranks = [self.kv_mgr.pp_rank] + else: + self.target_pp_ranks = [rank for rank in range(self.prefill_pp_size)] + self.kv_mgr.required_prefill_response_num_table[self.bootstrap_room] = ( self.required_prefill_response_num ) @@ -349,7 +365,7 @@ def __init__( if bootstrap_key not in self.kv_mgr.connection_pool: bootstrap_infos = [] for target_tp_rank in self.target_tp_ranks: - for target_pp_rank in range(self.prefill_pp_size): + for target_pp_rank in self.target_pp_ranks: bootstrap_info = self._get_bootstrap_info_from_server( target_tp_rank, self.target_dp_group, target_pp_rank ) diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py index 3fe895ed196..d4414d0843e 100644 --- a/python/sglang/srt/disaggregation/mooncake/conn.py +++ b/python/sglang/srt/disaggregation/mooncake/conn.py @@ -283,7 +283,7 @@ def _send_kvcache_generic( layers_params = None - # pp is not supported on the decode side yet + # Decode pp size should be equal to prefill pp size or 1 if self.is_mla_backend: src_kv_ptrs, dst_kv_ptrs, layers_current_pp_stage = ( self.get_mla_kv_ptrs_with_pp(src_data_ptrs, dst_data_ptrs) @@ -1199,7 +1199,7 @@ def _register_kv_args(self): packed_state_data_ptrs = b"".join( struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.state_data_ptrs ) - # Note(shangming): No need to add pp rank here since pp is not supported on the decode side yet + # Note(shangming): No need to add pp rank here since decode pp size should be equal to prefill pp size or 1 tp_rank = self.kv_mgr.kv_args.engine_rank kv_item_len = self.kv_mgr.kv_args.kv_item_lens[0] dst_tp_rank = str(tp_rank).encode("ascii") From 94d56a5edb82ae2fdf8220c7eecee203436bec5d Mon Sep 17 00:00:00 2001 From: bluecoffee8 Date: Wed, 3 Dec 2025 14:54:32 -0800 Subject: [PATCH 4/5] fix nixl --- python/sglang/srt/disaggregation/nixl/conn.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/disaggregation/nixl/conn.py b/python/sglang/srt/disaggregation/nixl/conn.py index c61c3e98009..547f1db8d56 100644 --- a/python/sglang/srt/disaggregation/nixl/conn.py +++ b/python/sglang/srt/disaggregation/nixl/conn.py @@ -323,12 +323,11 @@ def send_kvcache( src_kv_ptrs, dst_kv_ptrs, layers_current_pp_stage = ( self.get_mla_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs) ) - kv_item_len = self.kv_args.kv_item_lens[0] layers_params = [ ( src_kv_ptrs[layer_id], dst_kv_ptrs[layer_id], - kv_item_len, + self.kv_args.kv_item_lens[layer_id], ) for layer_id in range(layers_current_pp_stage) ] @@ -337,19 +336,18 @@ def send_kvcache( self.get_mha_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs) ) - kv_item_len = self.kv_args.kv_item_lens[0] 
layers_params = [ ( src_k_ptrs[layer_id], dst_k_ptrs[layer_id], - kv_item_len, + self.kv_args.kv_item_lens[layer_id], ) for layer_id in range(layers_current_pp_stage) ] + [ ( src_v_ptrs[layer_id], dst_v_ptrs[layer_id], - kv_item_len, + self.kv_args.kv_item_lens[layer_id], ) for layer_id in range(layers_current_pp_stage) ] @@ -608,7 +606,7 @@ def add_transfer_request( handles.append(kv_xfer_handle) # Only the last chunk we need to send the aux data. - if is_last and self.pp_group.is_last_rank: + if is_last: assert aux_index is not None aux_xfer_handle = self.send_aux( req.agent_name, From 4433a9d2a87f1fe1e49b9162608f7744840da524 Mon Sep 17 00:00:00 2001 From: bluecoffee8 Date: Thu, 4 Dec 2025 13:56:45 -0800 Subject: [PATCH 5/5] fix retract req --- python/sglang/srt/disaggregation/decode.py | 16 ++---- python/sglang/srt/managers/schedule_batch.py | 1 + .../sglang/srt/managers/scheduler_pp_mixin.py | 53 +++++++++---------- 3 files changed, 31 insertions(+), 39 deletions(-) diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py index 0ccbea0e836..9423b8cd9f1 100644 --- a/python/sglang/srt/disaggregation/decode.py +++ b/python/sglang/srt/disaggregation/decode.py @@ -234,7 +234,6 @@ def _init_kv_manager(self) -> BaseKVManager: kv_args.engine_rank = self.tp_rank % (attn_tp_size) kv_args.decode_tp_size = attn_tp_size - # Note(shangming): pp is not supported on the decode side yet, so its rank is fixed to 0 kv_args.pp_rank = self.pp_rank kv_args.system_dp_rank = self.scheduler.dp_rank kv_args.prefill_pp_size = self.prefill_pp_size @@ -300,6 +299,7 @@ def add(self, req: Req, is_retracted: bool = False) -> None: return if is_retracted: + req.retraction_mb_id = None self.retracted_queue.append(req) else: if req.bootstrap_host == FAKE_BOOTSTRAP_HOST: @@ -340,7 +340,7 @@ def extend(self, reqs: List[Req], is_retracted: bool = False) -> None: def resume_retracted_reqs( self, rids_to_check: Optional[List[str]] = None - ) -> Tuple[List[Req], bool]: + ) -> List[Req]: # TODO refactor the scheduling part, reuse with the unified engine logic as much as possible # allocate memory @@ -378,13 +378,7 @@ def resume_retracted_reqs( if i not in indices_to_remove ] - has_retracted_req = ( - bool(len(self.retracted_queue) > 0) - if rids_to_check is None - else any(req.rid in rids_to_check for req in self.retracted_queue) - ) - - return resumed_reqs, has_retracted_req + return resumed_reqs def _update_handshake_waiters( self, rids_to_check: Optional[List[str]] = None @@ -426,7 +420,7 @@ def _update_handshake_waiters( def pop_preallocated( self, rids_to_check: Optional[List[str]] = None - ) -> List[DecodeRequest]: + ) -> Tuple[List[DecodeRequest], List[DecodeRequest]]: """Pop the preallocated requests from the pending queue (FIFO).""" self._update_handshake_waiters(rids_to_check) @@ -968,7 +962,7 @@ def process_decode_queue(self: Scheduler): self.decode_offload_manager.check_offload_progress() # try to resume retracted requests if there are enough space for another `num_reserved_decode_tokens` decode steps - resumed_reqs, _ = self.disagg_decode_prealloc_queue.resume_retracted_reqs() + resumed_reqs = self.disagg_decode_prealloc_queue.resume_retracted_reqs() self.waiting_queue.extend(resumed_reqs) if len(self.disagg_decode_prealloc_queue.retracted_queue) > 0: # if there are still retracted requests, we do not allocate new requests diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 6a7f44a310e..bbcf1cdc0cd 100644 --- 
a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -651,6 +651,7 @@ def __init__( # The number of times this request has been retracted / preempted. self.retraction_count = 0 + self.retraction_mb_id = None # For metrics self.metrics_collector = metrics_collector diff --git a/python/sglang/srt/managers/scheduler_pp_mixin.py b/python/sglang/srt/managers/scheduler_pp_mixin.py index 6332ac754da..a1cfc63834f 100644 --- a/python/sglang/srt/managers/scheduler_pp_mixin.py +++ b/python/sglang/srt/managers/scheduler_pp_mixin.py @@ -1056,13 +1056,19 @@ def event_loop_pp_disagg_prefill(self: Scheduler): self.new_token_ratio = self.init_new_token_ratio self.maybe_sleep_on_idle() - def _pp_pd_get_retract_ids(self: Scheduler): + def _pp_pd_get_retract_ids(self: Scheduler, mb_id: int): # communicate pre-consensus retracted reqs + for req in self.disagg_decode_prealloc_queue.retracted_queue: + # assign retracted reqs to the current microbatch + if req.retraction_mb_id is None: + req.retraction_mb_id = mb_id curr_retract_rids = [ - req.rid for req in self.disagg_decode_prealloc_queue.retracted_queue + req.rid + for req in self.disagg_decode_prealloc_queue.retracted_queue + if req.retraction_mb_id == mb_id ] if self.pp_group.is_first_rank: - # First rank, get all retracted req ids + # First rank, get all retracted req ids for the microbatch return curr_retract_rids else: # Other ranks, receive the retracted reqs info from the previous rank and ensure the consensus @@ -1122,27 +1128,20 @@ def _pp_pd_get_decode_transferred_ids(self: Scheduler): ) return transferred_rids - # from process_decode_queue def process_retract_queue(self: Scheduler, retract_rids: Optional[List[str]]): if retract_rids is not None: # try to resume retracted requests if there are enough space for another `num_reserved_decode_tokens` decode steps - resumed_reqs, has_retracted_req = ( - self.disagg_decode_prealloc_queue.resume_retracted_reqs(retract_rids) + resumed_reqs = self.disagg_decode_prealloc_queue.resume_retracted_reqs( + retract_rids ) self.waiting_queue.extend(resumed_reqs) - return [req.rid for req in resumed_reqs], has_retracted_req - return None, False + return [req.rid for req in resumed_reqs] + return None - # from process_decode_queue - def process_prealloc_queue( - self: Scheduler, prealloc_rids: Optional[List[str]], has_retracted_req: bool - ): - if has_retracted_req: + def process_prealloc_queue(self: Scheduler, prealloc_rids: Optional[List[str]]): + if len(self.disagg_decode_prealloc_queue.retracted_queue) > 0: # if there are still retracted requests, we do not allocate new requests - return None - - # TODO: figure out if we need polling_count, probably do not since we do not need to poll - # in particular we do not call _update_handshake_waiters + return [[], []] if prealloc_rids is not None: ( @@ -1189,10 +1188,10 @@ def event_loop_pp_disagg_decode(self: Scheduler): # consensus rids consensus_retract_rids: Optional[List[str]] = None consensus_prealloc_rids: Optional[List[str]] = None - release_rids: Optional[List[str]] = None + release_rids: Optional[List[str]] = None # consensus transferred rids rmbs = [None] * self.pp_loop_size - bmbs = [None] * self.pp_loop_size + pmbs = [None] * self.pp_loop_size tmbs = [None] * self.pp_loop_size send_req_work = [] @@ -1232,12 +1231,12 @@ def event_loop_pp_disagg_decode(self: Scheduler): self._pp_commit_comm_work(send_req_work) # reaching consensus through PP ranks - retract_rids = self._pp_pd_get_retract_ids() + retract_rids = 
self._pp_pd_get_retract_ids(mb_id) rmbs[mb_id] = retract_rids self._pp_commit_comm_work(send_retract_work) prealloc_rids = self._pp_pd_get_prealloc_ids() - bmbs[mb_id] = prealloc_rids + pmbs[mb_id] = prealloc_rids self._pp_commit_comm_work(send_prealloc_work) transferred_rids = self._pp_pd_get_decode_transferred_ids() @@ -1302,7 +1301,7 @@ def event_loop_pp_disagg_decode(self: Scheduler): send_consensus_prealloc_work, consensus_prealloc_rids = ( self._pp_pd_send_consensus_bootstrapped_ids( # reuse the function - bmbs, + pmbs, next_first_rank_mb_id, consensus_prealloc_rids, prealloc_rids, @@ -1315,22 +1314,20 @@ def event_loop_pp_disagg_decode(self: Scheduler): ) ) - # from process_decode_queue if self.server_args.disaggregation_decode_enable_offload_kvcache: self.decode_offload_manager.check_offload_progress() - has_retracted_req = False if rmbs[next_mb_id] is not None: next_consensus_retract_rids = self._pp_recv_pyobj_from_prev_stage() - next_consensus_retract_rids, has_retracted_req = ( - self.process_retract_queue(next_consensus_retract_rids) + next_consensus_retract_rids = self.process_retract_queue( + next_consensus_retract_rids ) self._pp_commit_comm_work(send_consensus_retract_work) - if bmbs[next_mb_id] is not None: + if pmbs[next_mb_id] is not None: next_consensus_prealloc_rids = self._pp_recv_pyobj_from_prev_stage() next_consensus_prealloc_rids = self.process_prealloc_queue( - next_consensus_prealloc_rids, has_retracted_req + next_consensus_prealloc_rids ) self._pp_commit_comm_work(send_consensus_prealloc_work)
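
Note on the consensus pattern used throughout this series: `event_loop_pp_disagg_decode` reuses the same scheme as the prefill PP loop — each pipeline stage polls its local queues for request ids in a given state, the first stage forwards its local result downstream, and every later stage narrows the running result against its own view before acting, so all stages end up releasing, preallocating, or retracting the same request ids. The snippet below is a minimal, self-contained sketch of that combine step only; the names (consensus_step, local_ready, local_failed, prev) are hypothetical placeholders and not part of the SGLang codebase or of this patch.

    from typing import List, Optional, Tuple


    def consensus_step(
        pp_rank: int,
        local_ready: List[str],
        local_failed: List[str],
        prev: Optional[Tuple[List[str], List[str]]],
    ) -> Tuple[List[str], List[str]]:
        """Combine this rank's poll results with the previous rank's running consensus.

        Ready ids are intersected: a request is only acted on once every earlier
        stage has also observed it ready. Failed ids are unioned: any single stage
        seeing a failure is enough to drop the request on all stages.
        """
        if pp_rank == 0 or prev is None:
            # First stage starts the running consensus from its local poll results.
            return local_ready, local_failed
        prev_ready, prev_failed = prev
        ready = list(set(prev_ready) & set(local_ready))
        failed = list(set(prev_failed) | set(local_failed))
        return ready, failed


    if __name__ == "__main__":
        # Rank 0 sees r1/r2 ready; rank 1 only sees r1 ready and r3 failed.
        stage0 = consensus_step(0, ["r1", "r2"], [], None)
        stage1 = consensus_step(1, ["r1"], ["r3"], stage0)
        print(stage1)  # (['r1'], ['r3'])

The asymmetry (intersection for success, union for failure) is why the patches track good/bad rid lists separately per microbatch: acting on a request that a later stage has not yet bootstrapped would desynchronize the KV transfer, while aborting on any stage's failure is always safe.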