1919
2020def build_pooling_cursor (
2121 num_scheduled_tokens : list [int ],
22- padded_num_seqs : int ,
23- prompt_lens : jax .Array ,
22+ padded_num_reqs : int ,
2423):
25- assert len (prompt_lens ) == len (num_scheduled_tokens )
2624
2725 n_seq = len (num_scheduled_tokens )
28- num_scheduled_tokens_padded = jnp .zeros (padded_num_seqs )
29- num_scheduled_tokens_padded = num_scheduled_tokens_padded .at [:n_seq ].set (
26+ padded_num_scheduled_tokens = jnp .zeros (padded_num_reqs )
27+ padded_num_scheduled_tokens = padded_num_scheduled_tokens .at [:n_seq ].set (
3028 jnp .asarray (num_scheduled_tokens , dtype = jnp .int32 )
3129 )
32- cumsum = jnp .cumsum (num_scheduled_tokens_padded , dtype = jnp .int64 )
30+ cumsum = jnp .cumsum (padded_num_scheduled_tokens , dtype = jnp .int64 )
3331 first_token_indices = jnp .concatenate ((jnp .asarray ((0 ,)), cumsum [:- 1 ]))
34- last_token_indices = (first_token_indices + num_scheduled_tokens_padded - 1 ).astype (jnp .int64 )
32+ last_token_indices = (first_token_indices + padded_num_scheduled_tokens - 1 ).astype (jnp .int64 )
3533 last_token_indices = jnp .where (
36- num_scheduled_tokens_padded > 0 , last_token_indices , first_token_indices
34+ padded_num_scheduled_tokens > 0 , last_token_indices , first_token_indices
3735 )
38- return first_token_indices , last_token_indices
36+ return first_token_indices , last_token_indices , padded_num_scheduled_tokens
3937
4038
4139@functools .partial (
@@ -44,11 +42,9 @@ def build_pooling_cursor(
4442 "prompt_lens" ,
4543 "first_token_indices" ,
4644 "last_token_indices" ,
47- "normalize" ,
48- "num_reqs" ,
49- "padded_num_reqs" ,
45+ "num_scheduled_tokens" ,
5046 ),
51- meta_fields = ( "task" , ),
47+ meta_fields = ( ),
5248)
5349@dataclass
5450class TPUSupportedPoolingMetadata :
@@ -57,64 +53,42 @@ class TPUSupportedPoolingMetadata:
5753 prompt_lens : jax .Array
5854 first_token_indices : jax .Array
5955 last_token_indices : jax .Array
60- normalize : jax .Array
61- num_reqs : int
62- padded_num_reqs : int
63- task : str
56+ num_scheduled_tokens : jax .Array
6457
6558 @classmethod
6659 def from_input_batch (
6760 cls ,
6861 mesh : Mesh ,
6962 input_batch : InputBatch ,
70- num_scheduled_tokens : list [int ],
63+ padded_num_scheduled_tokens : list [int ],
7164 padded_num_reqs : int ,
7265 ) -> TPUSupportedPoolingMetadata :
7366 pooling_params_list = input_batch .get_pooling_params ()
7467
7568 num_reqs = input_batch .num_reqs
7669 assert len (pooling_params_list ) == num_reqs
70+ assert len (input_batch .num_prompt_tokens [:num_reqs ]) == len (padded_num_scheduled_tokens )
7771
78- padded_prompt_lens_np = np .zeros (padded_num_reqs , dtype = np .int32 )
79- padded_prompt_lens_np [:num_reqs ] = input_batch .num_prompt_tokens [:num_reqs ]
80-
81- normalize = np .full (padded_num_reqs , - 1 , dtype = np .int8 )
82-
83- # Instead of shutting down the whole program, we should just ignore it and make it return 'embed' by default,
84- # but provide a warning.
85- for idx , params in enumerate (pooling_params_list ):
86- if params .normalize is True :
87- normalize [idx ] = 1
88- elif params .normalize is False :
89- normalize [idx ] = 0
90-
91- if (task := params .task ) not in SUPPORTED_POOLING_TASKS :
92- logger .warning (
93- f"Unsupported pooling task '{ task } '. Supported tasks: { sorted (SUPPORTED_POOLING_TASKS )} . Defaulting to 'embed'."
94- )
95-
96- # maybe in the future if we need to support multiple tasks in one batch, we need to make sure each batch has only one task
97- # if not task_values:
98- # raise ValueError("Pooling metadata requires at least one request")
99- # if any(task != task_values[0] for task in task_values):
100- # raise ValueError("Mixed pooling tasks within the same batch are not supported yet")
101-
102- task = "embed"
103- first_token_indices , last_token_indices = build_pooling_cursor (
104- num_scheduled_tokens , padded_num_reqs , padded_prompt_lens_np [:num_reqs ]
72+ padded_prompt_lens = jnp .zeros (padded_num_reqs , dtype = np .int32 )
73+ padded_prompt_lens = padded_prompt_lens .at [:num_reqs ].set (input_batch .num_prompt_tokens [:num_reqs ])
74+
75+ first_token_indices , last_token_indices , padded_num_scheduled_tokens = build_pooling_cursor (
76+ padded_num_scheduled_tokens , padded_num_reqs
10577 )
10678
107- prompt_lens , normalize , first_token_indices , last_token_indices = device_array (
79+ prompt_lens , first_token_indices , last_token_indices , num_scheduled_tokens = device_array (
10880 mesh ,
109- (padded_prompt_lens_np , normalize , first_token_indices , last_token_indices ),
81+ (padded_prompt_lens , first_token_indices , last_token_indices , padded_num_scheduled_tokens ),
11082 )
11183
84+ #everything in pooling_metadata is padded.
11285 return cls (
11386 prompt_lens = prompt_lens ,
11487 first_token_indices = first_token_indices ,
11588 last_token_indices = last_token_indices ,
116- normalize = normalize ,
117- task = task ,
118- num_reqs = num_reqs ,
119- padded_num_reqs = padded_num_reqs ,
89+ num_scheduled_tokens = num_scheduled_tokens ,
12090 )
91+
92+
93+ def is_partial_prefill (pooling_metadata : TPUSupportedPoolingMetadata ):
94+ return not jnp .all (pooling_metadata .prompt_lens == pooling_metadata .num_scheduled_tokens )
0 commit comments