2323import traceback
2424import weakref
2525from concurrent .futures import ThreadPoolExecutor
26- from typing import Dict , List , Optional , Tuple
26+ from typing import Dict , List , Optional , Tuple , Union
2727
2828import numpy as np
2929import paddle
@@ -142,14 +142,16 @@ def __init__(self, cfg, start_queue=True):
142142 def start (self ):
143143 self .running = True
144144 if envs .ENABLE_V1_KVCACHE_SCHEDULER :
145- self .insert_task_to_worker_thread = threading .Thread (target = self ._scheduler_task_to_worker_v1 , daemon = True )
145+ self .insert_task_to_worker_thread = threading .Thread (
146+ target = self ._schedule_request_to_worker_v1 , daemon = True
147+ )
146148 else :
147- self .insert_task_to_worker_thread = threading .Thread (target = self ._insert_task_to_worker , daemon = True )
149+ self .insert_task_to_worker_thread = threading .Thread (target = self ._schedule_request_to_worker , daemon = True )
148150 self .insert_task_to_worker_thread .start ()
149151 self .token_processor .tasks_queue = self .engine_worker_queue
150152 self .token_processor .run ()
151153 if self .cfg .scheduler_config .splitwise_role != "mixed" :
152- self .split_mode_get_tasks ()
154+ self ._process_splitwise_task ()
153155
154156 def create_data_processor (self ):
155157 self .input_processor = InputPreprocessor (
@@ -310,7 +312,7 @@ def start_worker_queue_service(self, start_queue):
310312 ),
311313 )
312314
313- def insert_tasks (self , tasks , current_id = - 1 , allocated = False ):
315+ def insert_tasks (self , tasks : Union [ List [ Request ], List [ RequestOutput ]] , current_id = - 1 , allocated = False ):
314316 """
315317 Insert tasks to engine.
316318 """
@@ -572,7 +574,7 @@ def update_mm_requests_chunk_size(self, requests):
572574 patch_st += chunk_patch_num
573575 request .set ("prefill_chunk_info" , chunks_info )
574576
575- def _insert_task_to_worker (self ):
577+ def _schedule_request_to_worker (self ):
576578 """
577579 Insert task to engine thread, monitor scheduler request queue.
578580 if the engine has resource, insert task to engine
@@ -618,7 +620,7 @@ def _insert_task_to_worker(self):
618620 time .sleep (0.001 )
619621 continue
620622 if self .cfg .splitwise_version == "v2" and self .cfg .scheduler_config .splitwise_role == "decode" :
621- # the task in decode instance will processed in split_mode_get_tasks thread
623+ # the task in decode instance will be processed in the _process_splitwise_task thread
622624 continue
623625
624626 llm_logger .debug (f"get tasks from scheduler: { tasks } " )
@@ -637,7 +639,7 @@ def _insert_task_to_worker(self):
637639 err_msg = f"Error happend while insert task to engine: { e } , { traceback .format_exc ()!s} ."
638640 self .llm_logger .error (err_msg )
639641
640- def _scheduler_task_to_worker_v1 (self ):
642+ def _schedule_request_to_worker_v1 (self ):
641643 """
642644 Insert tasks to worker with scheduler v1 (ENABLE_V1_KVCACHE_SCHEDULER=1).
643645 """
@@ -921,7 +923,7 @@ def _zmq_send_generated_tokens(self):
921923 except Exception as e :
922924 llm_logger .error (f"Unexcepted error happend: { e } , { traceback .format_exc ()!s} " )
923925
924- def split_mode_get_tasks (self ):
926+ def _process_splitwise_task (self ):
925927 """
926928 Processing tasks from engine worker queue in splitwise deployment.
927929 For v0 version, prefill instance gets tasks from engine worker queue.
@@ -932,10 +934,25 @@ def split_mode_get_tasks(self):
932934
933935 def receiver_loop ():
934936 waiting_resource_requests = []
937+ waiting_ready_tasks = []
938+
939+ # Hold prefilled tasks until the api_server and scheduler in the decode
940+ # instance have received the matching request sent by the client
941+ def _decode_process_prefilled_task_v0_scheduler (input_tasks ):
942+ ready_tasks = []
943+ waiting_tasks = []
944+ for task in input_tasks :
945+ if not hasattr (self .scheduler , "has_request" ) or self .scheduler .has_request (task .request_id ):
946+ ready_tasks .append (task )
947+ else :
948+ waiting_tasks .append (task )
949+ self .insert_tasks (ready_tasks , allocated = True )
950+ if self .cfg .splitwise_version in ("v0" , "v2" ):
951+ self .scheduler .put_results (ready_tasks )
952+ return waiting_tasks
935953
936954 while self .running :
937955 try :
938-
939956 processed_indices = []
940957 for idx , task in enumerate (waiting_resource_requests ):
941958 if envs .ENABLE_V1_KVCACHE_SCHEDULER :
@@ -958,19 +975,24 @@ def receiver_loop():
958975 for idx in sorted (processed_indices , reverse = True ):
959976 waiting_resource_requests .pop (idx )
960977
961- if not self .engine_worker_queue .disaggregate_queue_empty ():
978+ waiting_ready_tasks = _decode_process_prefilled_task_v0_scheduler (waiting_ready_tasks )
979+
980+ if self .engine_worker_queue .disaggregate_queue_empty ():
981+ time .sleep (0.001 )
982+ else :
962983 items = self .engine_worker_queue .get_disaggregated_tasks ()
963984 for item in items :
964985 role = item [0 ]
965986 tasks = item [1 ]
966987
967- if role == "prefill" : # prefill instance gets tasks from engine worker queue
988+ # prefill instance gets tasks from engine worker queue
989+ if role == "prefill" :
968990 for task in tasks :
969991 task .max_tokens = task .min_tokens = 2
970992 self .insert_tasks (tasks )
971-
972- elif role == "decode" : # decode instance gets tasks from engine worker queue
973- if hasattr (tasks [0 ], "finished" ):
993+ # decode instance gets tasks from engine worker queue
994+ elif role == "decode" :
995+ if isinstance (tasks [0 ], RequestOutput ):
974996 self .llm_logger .debug (f"receive prefilled tasks, { tasks } " )
975997 if not isinstance (tasks , list ):
976998 tasks = [tasks ]
@@ -1009,11 +1031,9 @@ def receiver_loop():
10091031 self .resource_manager .insert_task_for_decoding (task )
10101032
10111033 else :
1012- self .insert_tasks (tasks , allocated = True )
1013- if self .cfg .splitwise_version in ("v0" , "v2" ):
1014- self .scheduler .put_results (tasks )
1015- else :
1016- self .llm_logger .debug (f"receive tasks to allocate resource, { tasks } " )
1034+ waiting_ready_tasks .extend (_decode_process_prefilled_task_v0_scheduler (tasks ))
1035+ elif isinstance (tasks [0 ], Request ):
1036+ self .llm_logger .debug (f"receive tasks to preallocate resource, { tasks } " )
10171037 if len (waiting_resource_requests ):
10181038 self .llm_logger .info (f"Waiting for resource for task { tasks [0 ].request_id } " )
10191039 waiting_resource_requests .extend (tasks )
@@ -1044,9 +1064,8 @@ def receiver_loop():
10441064 self .llm_logger .info (
10451065 f"Added { len (new_waiting )} tasks to waiting queue"
10461066 )
1047-
1048- else :
1049- time .sleep (0.001 )
1067+ else :
1068+ raise ValueError (f"Unsupported task type: { type (tasks [0 ])} " )
10501069
10511070 except Exception as e :
10521071 self .llm_logger .error (f"Error in main loop: { e } " )
0 commit comments