diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py
index ae9f7541d..6564ebefe 100644
--- a/lightllm/server/api_cli.py
+++ b/lightllm/server/api_cli.py
@@ -214,6 +214,12 @@ def make_argument_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "--router_max_new_token_len", type=int, default=1024, help="the request max new token len for router"
     )
+    parser.add_argument(
+        "--past_future_scheduler",
+        action="store_true",
+        help="""use past_future_scheduler for adaptive request new token len prediction,
+        override --router_token_ratio and --router_max_new_token_len (still used during warmup)""",
+    )
     parser.add_argument(
         "--router_max_wait_tokens",
diff --git a/lightllm/server/core/objs/req.py b/lightllm/server/core/objs/req.py
index 195de4148..42d5bf076 100644
--- a/lightllm/server/core/objs/req.py
+++ b/lightllm/server/core/objs/req.py
@@ -266,7 +266,7 @@ def get_all_prompt_metadata(self):
 class ChunkedPrefillReq(Req):
     _pack_ = 4
 
-    def get_tuple_tokens(self, is_busy, router_max_new_token_len):
+    def get_tuple_tokens(self, is_busy, router_max_new_token_len, has_out_len_factor=1.1):
         args = get_env_start_args()
         # chuncked prefill 推理的过程中,存在很多模式的延迟 step 推理的控制, 用于
         # 保证更好的包间数据或者是提升 dp 模式下prefill 的效率,但是在估计 token 显存
@@ -283,7 +283,7 @@ def get_tuple_tokens(self, is_busy, router_max_new_token_len):
             cur_max_new_token_len = self.sample_params.max_new_tokens
         else:
             cur_max_new_token_len = min(
-                self.sample_params.max_new_tokens, max(int(1.1 * has_out_len), router_max_new_token_len)
+                self.sample_params.max_new_tokens, max(int(has_out_len_factor * has_out_len), router_max_new_token_len)
             )
 
         a_len = max(self.input_len + has_out_len + 1, self.shm_cur_kv_len + 1)
diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py
index be369d9ca..eae4843f8 100644
--- a/lightllm/server/httpserver/manager.py
+++ b/lightllm/server/httpserver/manager.py
@@ -649,9 +649,9 @@ async def recycle_resource_loop(self):
                 continue
 
             logger.info(
-                f"left req id {req_status.group_req_objs.group_req_id}"
-                f"can release {req_status.group_req_objs.shm_req_objs[0].can_released_mark} "
-                f"refcount {req_status.group_req_objs.shm_req_objs[0].ref_count}"
+                f"left req id: {req_status.group_req_objs.group_req_id}, "
+                f"can release: {req_status.group_req_objs.shm_req_objs[0].can_released_mark}, "
+                f"refcount: {req_status.group_req_objs.shm_req_objs[0].ref_count}"
             )
 
         return
diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py
index 03fc694c9..91c550ff1 100644
--- a/lightllm/server/router/manager.py
+++ b/lightllm/server/router/manager.py
@@ -22,6 +22,7 @@
 from .shm_reqs_io_buffer import ShmReqsIOBuffer
 from lightllm.utils.log_utils import init_logger, log_time_ready
 from lightllm.server.router.token_load import TokenLoad
+from lightllm.server.router.req_queue.chunked_prefill.impl_past_future import PastFutureQueue
 from lightllm.server.metrics.manager import MetricClient
 from lightllm.common.basemodel.infer_lock import g_router_lock
 from lightllm.common.mem_manager import ReadOnlyStaticsMemoryManager
@@ -319,6 +320,8 @@ def _add_new_batch_to_running_batch(self, new_batch: Batch):
 
     def _filter_reqs_from_running_batch(self):
         if self.running_batch is not None:
+            if isinstance(self.req_queue, PastFutureQueue):
+                self.req_queue.record_finished_len_from_batch(self.running_batch)
             self.running_batch.filter_out_finished_req(self.shm_req_manager)
             if self.running_batch.is_clear():
                 self.running_batch = None
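For reference, a minimal standalone sketch (not part of the patch) of the non-busy branch of the token estimate that get_tuple_tokens now parameterizes via has_out_len_factor. The helper name estimate_new_token_len and the concrete numbers are illustrative only.

# Illustrative only -- mirrors the non-busy branch of ChunkedPrefillReq.get_tuple_tokens.
def estimate_new_token_len(max_new_tokens, has_out_len, router_max_new_token_len, has_out_len_factor=1.1):
    return min(max_new_tokens, max(int(has_out_len_factor * has_out_len), router_max_new_token_len))

print(estimate_new_token_len(512, 100, 1024))                         # 512: capped by the request's max_new_tokens
print(estimate_new_token_len(2048, 100, 1024))                        # 1024: the static router default dominates
print(estimate_new_token_len(2048, 100, 64, has_out_len_factor=1.0))  # 100: a sampled length passed in by the new scheduler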
diff --git a/lightllm/server/router/req_queue/__init__.py b/lightllm/server/router/req_queue/__init__.py
index 10867b6e5..831a38688 100644
--- a/lightllm/server/router/req_queue/__init__.py
+++ b/lightllm/server/router/req_queue/__init__.py
@@ -1,29 +1,37 @@
 from .chunked_prefill.impl_for_pd_decode import QueueForPDDecode
 from .chunked_prefill.impl import ChunkedPrefillQueue
 from .chunked_prefill.beam_impl import ChunkedBeamContinuesBatchQueue
+from .chunked_prefill.impl_past_future import PastFutureQueue
 from .dp_base_queue import DpQueue
 
 
 def _get_req_queue_class(args, router, dp_size_in_node: int):
+    if args.past_future_scheduler:
+        if args.diverse_mode:
+            raise ValueError("Diverse mode is not supported with past future scheduler yet")
+        chunked_prefill_queue_impl = PastFutureQueue
+    else:
+        chunked_prefill_queue_impl = ChunkedPrefillQueue
+
     if args.diverse_mode:
         return ChunkedBeamContinuesBatchQueue
     if args.token_healing_mode:
-        return ChunkedPrefillQueue
+        return chunked_prefill_queue_impl
     if args.output_constraint_mode != "none":
-        return ChunkedPrefillQueue
+        return chunked_prefill_queue_impl
     if args.first_token_constraint_mode:
-        return ChunkedPrefillQueue
+        return chunked_prefill_queue_impl
     if args.run_mode == "decode":
         return QueueForPDDecode
     if args.run_mode == "prefill":
-        return ChunkedPrefillQueue
+        return chunked_prefill_queue_impl
     if args.disable_chunked_prefill:
         # 虽然也使用chuncked prefill queue 但是由于 args.chunked_prefill_size = args.max_req_total_len
         # 所以调度的实际行为类似过去的 continues batch 调度,所以将两种调度的实现统一为一种实现,减少代码重复。
-        return ChunkedPrefillQueue
+        return chunked_prefill_queue_impl
     else:
-        return ChunkedPrefillQueue
+        return chunked_prefill_queue_impl
 
 
 def build_req_queue(args, router, dp_size_in_node: int):
diff --git a/lightllm/server/router/req_queue/chunked_prefill/impl.py b/lightllm/server/router/req_queue/chunked_prefill/impl.py
index 3730300a5..96138510c 100644
--- a/lightllm/server/router/req_queue/chunked_prefill/impl.py
+++ b/lightllm/server/router/req_queue/chunked_prefill/impl.py
@@ -21,8 +21,7 @@ def _init_cache_list(self, current_batch: Batch, is_busy):
             self.cache_len_list = []
         return
 
-    # @calculate_time(show=True, min_cost_ms=0.1)
-    def _can_add_new_req(self, req: Req, is_busy, new_batch_first_router_need_tokens):
+    def _update_cache_len_list(self, req: Req, is_busy):
         self.cache_len_list.append(req.get_tuple_tokens(is_busy, self.router_max_new_token_len))  # hard to analysis
         self.cache_len_list.sort(key=lambda x: -x[1])
 
@@ -32,6 +31,11 @@ def _can_add_new_req(self, req: Req, is_busy, new_batch_first_router_need_tokens
         size_array = np.arange(1, len(self.cache_len_list) + 1, 1)
 
         need_max_token_num = (left_out_len_array * size_array + cum_run_len_array).max()
+        return need_max_token_num
+
+    # @calculate_time(show=True, min_cost_ms=0.1)
+    def _can_add_new_req(self, req: Req, is_busy, new_batch_first_router_need_tokens):
+        need_max_token_num = self._update_cache_len_list(req, is_busy)
         with g_router_lock.obj:
             ok_token_num = (
                 need_max_token_num + self.router.shared_token_load.get_frozened_token_count(self.dp_index)
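A worked toy example (not part of the patch) of the peak-token estimate that the refactored _update_cache_len_list returns. The tuples below are made-up (already_used_tokens, left_output_len) pairs in the format produced by get_tuple_tokens.

# Illustrative only -- reproduces the need_max_token_num computation on example data.
import numpy as np

cache_len_list = [(300, 900), (500, 400), (1200, 100)]  # already sorted by descending left_output_len
left_out_len_array = np.array([e[1] for e in cache_len_list])  # [900, 400, 100]
cum_run_len_array = np.cumsum([e[0] for e in cache_len_list])  # [300, 800, 2000]
size_array = np.arange(1, len(cache_len_list) + 1)             # [1, 2, 3]
need_max_token_num = (left_out_len_array * size_array + cum_run_len_array).max()
print(need_max_token_num)  # max(900*1+300, 400*2+800, 100*3+2000) = 2300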
diff --git a/lightllm/server/router/req_queue/chunked_prefill/impl_past_future.py b/lightllm/server/router/req_queue/chunked_prefill/impl_past_future.py
new file mode 100644
index 000000000..40882e493
--- /dev/null
+++ b/lightllm/server/router/req_queue/chunked_prefill/impl_past_future.py
@@ -0,0 +1,80 @@
+import bisect
+from collections import deque
+import random
+from typing import List, Tuple
+import numpy as np
+from ...batch import Batch, Req
+from .impl import ChunkedPrefillQueue
+
+
+class PastFutureQueue(ChunkedPrefillQueue):
+    WINDOW_SIZE = 200
+    MINIMUM_SAMPLES = 200
+    MAXIMUM_LISTS = 5
+    REVERSED = 0.05
+    COMPLIANCE_IS_BUSY_FLAG = False
+
+    def __init__(self, args, router, dp_index, dp_size_in_node) -> None:
+        super().__init__(args, router, dp_index, dp_size_in_node)
+        initial_len = args.router_max_new_token_len
+        self.history_output_len = deque([initial_len] * (self.WINDOW_SIZE // 2), maxlen=self.WINDOW_SIZE)
+
+    def _sample_cache_list(self, reqs: List[Req], is_busy, samples=1) -> List[List[Tuple[int, int]]]:
+        cache_len_lists = [[] for _ in range(samples)]
+        his_Lo = sorted(self.history_output_len)
+        for req in reqs:
+            dl = req.shm_cur_output_len
+            pos = bisect.bisect(his_Lo, dl)
+
+            sample_range = [dl] + his_Lo[pos:] + [req.sample_params.max_new_tokens]  # at least 2 value
+
+            for i in range(samples):
+                random_p = np.random.random() * (len(sample_range) - 1)
+                l_pos = int(random_p)
+                l_val, r_val = sample_range[l_pos : l_pos + 2]
+
+                # Linear interpolation
+                sampled = round(l_val + (r_val - l_val) * (random_p - l_pos))
+                cache_len_lists[i].append(
+                    req.get_tuple_tokens(is_busy and self.COMPLIANCE_IS_BUSY_FLAG, sampled, has_out_len_factor=1.0)
+                )
+
+        return cache_len_lists
+
+    def _calc_max_token_num_needed(self, cache_len_list: List[Tuple[int, int]]) -> int:
+        cache_len_list.sort(key=lambda x: -x[1])
+
+        left_out_len_array = np.array([e[1] for e in cache_len_list])
+        has_run_len_array = np.array([e[0] for e in cache_len_list])
+        cum_run_len_array = np.cumsum(has_run_len_array)
+        size_array = np.arange(1, len(cache_len_list) + 1, 1)
+
+        need_max_token_num = (left_out_len_array * size_array + cum_run_len_array).max()
+        return need_max_token_num
+
+    def _init_cache_list(self, current_batch: Batch, is_busy):
+        if current_batch is not None:
+            n_lists = min(self.MAXIMUM_LISTS, int(self.MINIMUM_SAMPLES / len(current_batch.reqs)) + 1)
+            local_reqs = [req for req in current_batch.reqs if req.sample_params.suggested_dp_index == self.dp_index]
+            self._cache_len_lists = self._sample_cache_list(local_reqs, is_busy, samples=n_lists)
+        else:
+            self._cache_len_lists = [[]]
+        self.cache_len_list = self._cache_len_lists[0]  # keep compatibility
+
+    def _update_cache_len_list(self, req: Req, is_busy):
+        need_max_token_nums = []
+        for li in self._cache_len_lists:
+            newreq_output_len_sample = random.choice(self.history_output_len)
+            li.append(
+                req.get_tuple_tokens(
+                    is_busy and self.COMPLIANCE_IS_BUSY_FLAG, newreq_output_len_sample, has_out_len_factor=1.0
+                )
+            )
+            need_max_token_nums.append(self._calc_max_token_num_needed(li))
+        need_max_token_num = np.max(need_max_token_nums)
+        return need_max_token_num
+
+    def record_finished_len_from_batch(self, batch: Batch):
+        for req in batch.reqs:
+            if req.shm_infer_released:
+                self.history_output_len.append(req.shm_cur_output_len)
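A minimal standalone sketch (not part of the patch) of the past/future sampling idea behind PastFutureQueue._sample_cache_list: a plausible final output length for a running request is drawn by interpolating between its current output length, the recorded lengths of finished requests that are at least as large, and the request's max_new_tokens cap. The helper name sample_final_len and the example history are illustrative only.

# Illustrative only -- samples one candidate final output length from a history window.
import bisect
import numpy as np

def sample_final_len(cur_output_len, history, max_new_tokens):
    his_lo = sorted(history)
    pos = bisect.bisect(his_lo, cur_output_len)
    # candidate anchors: tokens already produced, every historical length >= that, and the hard cap
    sample_range = [cur_output_len] + his_lo[pos:] + [max_new_tokens]
    random_p = np.random.random() * (len(sample_range) - 1)
    l_pos = int(random_p)
    l_val, r_val = sample_range[l_pos : l_pos + 2]
    return round(l_val + (r_val - l_val) * (random_p - l_pos))  # linear interpolation between neighbours

history = [120, 250, 400, 900]  # made-up finished-request output lengths
print(sample_final_len(200, history, max_new_tokens=1024))  # falls somewhere in [200, 1024]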