Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 33 additions & 10 deletions python/sglang/srt/disaggregation/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from collections import deque
from dataclasses import dataclass
from http import HTTPStatus
from typing import TYPE_CHECKING, List, Optional, Type, Union
from typing import TYPE_CHECKING, List, Optional, Tuple, Type, Union

import torch
from torch.distributed import ProcessGroup
Expand Down Expand Up @@ -199,6 +199,7 @@ def __init__(
bootstrap_port: int,
max_total_num_tokens: int,
prefill_pp_size: int,
pp_rank: int,
num_reserved_decode_tokens: int,
transfer_backend: TransferBackend,
):
Expand All @@ -220,6 +221,7 @@ def __init__(
self.bootstrap_port = bootstrap_port
self.max_total_num_tokens = max_total_num_tokens
self.prefill_pp_size = prefill_pp_size
self.pp_rank = pp_rank
self.num_reserved_decode_tokens = num_reserved_decode_tokens
self.transfer_backend = transfer_backend
# Queue for requests pending pre-allocation
Expand All @@ -236,8 +238,7 @@ def _init_kv_manager(self) -> BaseKVManager:
kv_args.engine_rank = self.tp_rank % (attn_tp_size)

kv_args.decode_tp_size = attn_tp_size
# Note(shangming): pp is not supported on the decode side yet, so its rank is fixed to 0
kv_args.pp_rank = 0
kv_args.pp_rank = self.pp_rank
kv_args.system_dp_rank = self.scheduler.dp_rank
kv_args.prefill_pp_size = self.prefill_pp_size
kv_data_ptrs, kv_data_lens, kv_item_lens = (
Expand Down Expand Up @@ -302,6 +303,7 @@ def add(self, req: Req, is_retracted: bool = False) -> None:
return

if is_retracted:
req.retraction_mb_id = None
self.retracted_queue.append(req)
else:
if req.bootstrap_host == FAKE_BOOTSTRAP_HOST:
Expand Down Expand Up @@ -340,7 +342,9 @@ def extend(self, reqs: List[Req], is_retracted: bool = False) -> None:
for req in reqs:
self.add(req, is_retracted=is_retracted)

def resume_retracted_reqs(self) -> List[Req]:
def resume_retracted_reqs(
self, rids_to_check: Optional[List[str]] = None
) -> List[Req]:
# TODO: refactor the scheduling part to reuse the unified engine logic as much as possible

# allocate memory
Expand All @@ -349,6 +353,9 @@ def resume_retracted_reqs(self) -> List[Req]:
allocatable_tokens = self._allocatable_tokens(count_retracted=False)

for i, req in enumerate(self.retracted_queue):
if rids_to_check is not None and req.rid not in rids_to_check:
continue

if self.req_to_token_pool.available_size() <= 0:
break

Expand Down Expand Up @@ -377,7 +384,9 @@ def resume_retracted_reqs(self) -> List[Req]:

return resumed_reqs

def _update_handshake_waiters(self) -> None:
def _update_handshake_waiters(
self, rids_to_check: Optional[List[str]] = None
) -> None:
if not self.queue:
return

Expand All @@ -389,6 +398,9 @@ def _update_handshake_waiters(self) -> None:
)

for i, (decode_req, poll) in enumerate(zip(self.queue, polls)):
if rids_to_check is not None and decode_req.req.rid not in rids_to_check:
continue

if poll == KVPoll.Bootstrapping:
pass
elif poll == KVPoll.WaitingForInput:
Expand All @@ -410,10 +422,13 @@ def _update_handshake_waiters(self) -> None:
else:
raise ValueError(f"Unexpected poll case: {poll}")

def pop_preallocated(self) -> List[DecodeRequest]:
def pop_preallocated(
self, rids_to_check: Optional[List[str]] = None
) -> Tuple[List[DecodeRequest], List[DecodeRequest]]:
"""Pop the preallocated requests from the pending queue (FIFO)."""
self._update_handshake_waiters()
self._update_handshake_waiters(rids_to_check)

failed_reqs = []
preallocated_reqs = []
indices_to_remove = set()

Expand All @@ -428,14 +443,20 @@ def pop_preallocated(self) -> List[DecodeRequest]:
)
# First, remove all failed requests from the queue
for i, decode_req in enumerate(self.queue):
if rids_to_check is not None and decode_req.req.rid not in rids_to_check:
continue
if isinstance(decode_req.req.finished_reason, FINISH_ABORT):
self.scheduler.stream_output(
[decode_req.req], decode_req.req.return_logprob
)
failed_reqs.append(decode_req)
indices_to_remove.add(i)

# Then, preallocate the remaining requests if possible
for i, decode_req in enumerate(self.queue):
if rids_to_check is not None and decode_req.req.rid not in rids_to_check:
continue

if i in indices_to_remove:
continue

Expand Down Expand Up @@ -544,7 +565,7 @@ def pop_preallocated(self) -> List[DecodeRequest]:
entry for i, entry in enumerate(self.queue) if i not in indices_to_remove
]

return preallocated_reqs
return preallocated_reqs, failed_reqs

@property
def num_tokens_pre_allocated(self):
Expand Down Expand Up @@ -726,7 +747,7 @@ def _commit_transfer_to_req(self, decode_req: DecodeRequest) -> None:
)
decode_req.req.time_stats.wait_queue_entry_time = time.perf_counter()

def pop_transferred(self) -> List[Req]:
def pop_transferred(self, rids_to_check: Optional[List[str]] = None) -> List[Req]:
if not self.queue:
return []
polls = poll_and_all_reduce(
Expand All @@ -736,6 +757,8 @@ def pop_transferred(self) -> List[Req]:
transferred_reqs = []
indices_to_remove = set()
for i, (decode_req, poll) in enumerate(zip(self.queue, polls)):
if rids_to_check is not None and decode_req.req.rid not in rids_to_check:
continue
if poll == KVPoll.Failed:
error_message = f"Decode transfer failed for request rank={self.tp_rank} {decode_req.req.rid=} {decode_req.req.bootstrap_room=}"
try:
Expand Down Expand Up @@ -958,7 +981,7 @@ def process_decode_queue(self: Scheduler):
self.polling_count = (self.polling_count + 1) % self.polling_interval

if self.polling_count % self.polling_interval == 0:
req_conns = self.disagg_decode_prealloc_queue.pop_preallocated()
req_conns, _ = self.disagg_decode_prealloc_queue.pop_preallocated()
self.disagg_decode_transfer_queue.extend(req_conns)
alloc_reqs = (
self.disagg_decode_transfer_queue.pop_transferred()
Expand Down
2 changes: 1 addition & 1 deletion python/sglang/srt/entrypoints/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -729,7 +729,7 @@ def _set_envs_and_config(server_args: ServerArgs):
if server_args.attention_backend == "flashinfer":
assert_pkg_version(
"flashinfer_python",
"0.5.3",
"0.5.2",
"Please uninstall the old version and "
"reinstall the latest version by following the instructions "
"at https://docs.flashinfer.ai/installation.html.",
Expand Down
1 change: 1 addition & 0 deletions python/sglang/srt/managers/schedule_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,6 +673,7 @@ def __init__(

# The number of times this request has been retracted / preempted.
self.retraction_count = 0
self.retraction_mb_id = None

# For metrics
self.metrics_collector = metrics_collector
Expand Down
6 changes: 5 additions & 1 deletion python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -897,6 +897,7 @@ def init_disaggregation(self):
bootstrap_port=self.server_args.disaggregation_bootstrap_port,
max_total_num_tokens=self.max_total_num_tokens,
prefill_pp_size=self.server_args.disaggregation_prefill_pp,
pp_rank=self.pp_rank,
num_reserved_decode_tokens=self.server_args.num_reserved_decode_tokens,
transfer_backend=self.transfer_backend,
)
Expand Down Expand Up @@ -2728,7 +2729,10 @@ def run_scheduler_process(
if scheduler.enable_overlap:
scheduler.event_loop_overlap_disagg_decode()
else:
scheduler.event_loop_normal_disagg_decode()
if server_args.pp_size > 1:
scheduler.event_loop_pp_disagg_decode()
else:
scheduler.event_loop_normal_disagg_decode()

except Exception:
traceback = get_exception_traceback()
Expand Down
Loading