Commit 34cb3a0

simplify
1 parent 198d49f commit 34cb3a0

4 files changed: +31 -61 lines

tpu_inference/layers/jax/pool/pooler.py

Lines changed: 2 additions & 2 deletions

@@ -4,7 +4,7 @@
 import jax
 import jax.numpy as jnp
 from flax import nnx
-from tpu_inference.layers.jax.pool.pooling_metadata import TPUSupportedPoolingMetadata
+from tpu_inference.layers.jax.pool.pooling_metadata import TPUSupportedPoolingMetadata, is_partial_prefill

 from vllm.config.pooler import PoolerConfig

@@ -212,7 +212,7 @@ def __call__(
         return self.head(pooled, pooling_metadata)

     def get_supported_tasks(self) -> set[str]:
-        return {"embed"}
+        return ("embed",)


 def normalize(embeddings: jax.Array) -> jax.Array:
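Note on the hunk above: switching from {"embed"} to ("embed",) lines up with the runner's tuple[SupportedTask, ...] signature, which now delegates to the pooler (see tpu_runner.py below). The import of is_partial_prefill is added here, but its call site is not part of this diff; the following is only a sketch of how a pooler might consult it in eager mode -- the maybe_pool helper and pool_fn argument are hypothetical, not from this commit.

from tpu_inference.layers.jax.pool.pooling_metadata import (
    TPUSupportedPoolingMetadata,
    is_partial_prefill,
)

def maybe_pool(hidden_states, pooling_metadata: TPUSupportedPoolingMetadata, pool_fn):
    # Skip pooling while any request in the batch is still being chunk-prefilled;
    # is_partial_prefill compares prompt_lens against num_scheduled_tokens.
    # (A plain `if` works outside jit; under jit this would need lax.cond.)
    if is_partial_prefill(pooling_metadata):
        return None
    return pool_fn(hidden_states, pooling_metadata)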

tpu_inference/layers/jax/pool/pooling_metadata.py

Lines changed: 25 additions & 51 deletions

@@ -19,23 +19,21 @@

 def build_pooling_cursor(
     num_scheduled_tokens: list[int],
-    padded_num_seqs: int,
-    prompt_lens: jax.Array,
+    padded_num_reqs: int,
 ):
-    assert len(prompt_lens) == len(num_scheduled_tokens)

     n_seq = len(num_scheduled_tokens)
-    num_scheduled_tokens_padded = jnp.zeros(padded_num_seqs)
-    num_scheduled_tokens_padded = num_scheduled_tokens_padded.at[:n_seq].set(
+    padded_num_scheduled_tokens = jnp.zeros(padded_num_reqs)
+    padded_num_scheduled_tokens = padded_num_scheduled_tokens.at[:n_seq].set(
         jnp.asarray(num_scheduled_tokens, dtype=jnp.int32)
     )
-    cumsum = jnp.cumsum(num_scheduled_tokens_padded, dtype=jnp.int64)
+    cumsum = jnp.cumsum(padded_num_scheduled_tokens, dtype=jnp.int64)
     first_token_indices = jnp.concatenate((jnp.asarray((0,)), cumsum[:-1]))
-    last_token_indices = (first_token_indices + num_scheduled_tokens_padded - 1).astype(jnp.int64)
+    last_token_indices = (first_token_indices + padded_num_scheduled_tokens - 1).astype(jnp.int64)
     last_token_indices = jnp.where(
-        num_scheduled_tokens_padded > 0, last_token_indices, first_token_indices
+        padded_num_scheduled_tokens > 0, last_token_indices, first_token_indices
     )
-    return first_token_indices, last_token_indices
+    return first_token_indices, last_token_indices, padded_num_scheduled_tokens


 @functools.partial(
@@ -44,11 +42,9 @@ def build_pooling_cursor(
         "prompt_lens",
         "first_token_indices",
         "last_token_indices",
-        "normalize",
-        "num_reqs",
-        "padded_num_reqs",
+        "num_scheduled_tokens",
     ),
-    meta_fields=("task",),
+    meta_fields=(),
 )
 @dataclass
 class TPUSupportedPoolingMetadata:
@@ -57,64 +53,42 @@ class TPUSupportedPoolingMetadata:
     prompt_lens: jax.Array
     first_token_indices: jax.Array
     last_token_indices: jax.Array
-    normalize: jax.Array
-    num_reqs: int
-    padded_num_reqs: int
-    task: str
+    num_scheduled_tokens: jax.Array

     @classmethod
     def from_input_batch(
         cls,
         mesh: Mesh,
         input_batch: InputBatch,
-        num_scheduled_tokens: list[int],
+        padded_num_scheduled_tokens: list[int],
         padded_num_reqs: int,
     ) -> TPUSupportedPoolingMetadata:
         pooling_params_list = input_batch.get_pooling_params()

         num_reqs = input_batch.num_reqs
         assert len(pooling_params_list) == num_reqs
+        assert len(input_batch.num_prompt_tokens[:num_reqs]) == len(padded_num_scheduled_tokens)

-        padded_prompt_lens_np = np.zeros(padded_num_reqs, dtype=np.int32)
-        padded_prompt_lens_np[:num_reqs] = input_batch.num_prompt_tokens[:num_reqs]
-
-        normalize = np.full(padded_num_reqs, -1, dtype=np.int8)
-
-        # Instead of shutting down the whole program, we should just ignore it and make it return 'embed' by default,
-        # but provide a warning.
-        for idx, params in enumerate(pooling_params_list):
-            if params.normalize is True:
-                normalize[idx] = 1
-            elif params.normalize is False:
-                normalize[idx] = 0
-
-            if (task := params.task) not in SUPPORTED_POOLING_TASKS:
-                logger.warning(
-                    f"Unsupported pooling task '{task}'. Supported tasks: {sorted(SUPPORTED_POOLING_TASKS)}. Defaulting to 'embed'."
-                )
-
-        # maybe in the future if we need to support multiple tasks in one batch, we need to make sure each batch has only one task
-        # if not task_values:
-        #     raise ValueError("Pooling metadata requires at least one request")
-        # if any(task != task_values[0] for task in task_values):
-        #     raise ValueError("Mixed pooling tasks within the same batch are not supported yet")
-
-        task = "embed"
-        first_token_indices, last_token_indices = build_pooling_cursor(
-            num_scheduled_tokens, padded_num_reqs, padded_prompt_lens_np[:num_reqs]
+        padded_prompt_lens = jnp.zeros(padded_num_reqs, dtype=np.int32)
+        padded_prompt_lens = padded_prompt_lens.at[:num_reqs].set(input_batch.num_prompt_tokens[:num_reqs])
+
+        first_token_indices, last_token_indices, padded_num_scheduled_tokens = build_pooling_cursor(
+            padded_num_scheduled_tokens, padded_num_reqs
         )

-        prompt_lens, normalize, first_token_indices, last_token_indices = device_array(
+        prompt_lens, first_token_indices, last_token_indices, num_scheduled_tokens = device_array(
             mesh,
-            (padded_prompt_lens_np, normalize, first_token_indices, last_token_indices),
+            (padded_prompt_lens, first_token_indices, last_token_indices, padded_num_scheduled_tokens),
         )

+        # everything in pooling_metadata is padded.
         return cls(
             prompt_lens=prompt_lens,
             first_token_indices=first_token_indices,
             last_token_indices=last_token_indices,
-            normalize=normalize,
-            task=task,
-            num_reqs=num_reqs,
-            padded_num_reqs=padded_num_reqs,
+            num_scheduled_tokens=num_scheduled_tokens,
         )
+
+
+def is_partial_prefill(pooling_metadata: TPUSupportedPoolingMetadata):
+    return not jnp.all(pooling_metadata.prompt_lens == pooling_metadata.num_scheduled_tokens)
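As a quick sanity check of the new cursor layout, here is a standalone toy re-statement of the arithmetic in build_pooling_cursor and is_partial_prefill above. This is not part of the commit; int32 is used throughout so it runs without enabling JAX x64, and the input values are invented.

import jax.numpy as jnp

def build_pooling_cursor_sketch(num_scheduled_tokens, padded_num_reqs):
    # Mirrors the diffed build_pooling_cursor: pad the per-request token counts,
    # then derive first/last token indices from the running sum.
    n_seq = len(num_scheduled_tokens)
    padded = jnp.zeros(padded_num_reqs, dtype=jnp.int32)
    padded = padded.at[:n_seq].set(jnp.asarray(num_scheduled_tokens, dtype=jnp.int32))
    cumsum = jnp.cumsum(padded)
    first = jnp.concatenate((jnp.zeros(1, dtype=cumsum.dtype), cumsum[:-1]))
    last = first + padded - 1
    # Padding slots (zero scheduled tokens) fall back to their first index.
    last = jnp.where(padded > 0, last, first)
    return first, last, padded

first, last, padded = build_pooling_cursor_sketch([3, 2], padded_num_reqs=4)
print(first)   # [0 3 5 5]
print(last)    # [2 4 5 5]
print(padded)  # [3 2 0 0]

# is_partial_prefill then asks whether every request's scheduled tokens cover
# its full prompt; with prompt_lens padded the same way, padding slots compare
# equal (0 == 0) and do not affect the result.
prompt_lens = jnp.asarray([3, 7, 0, 0], dtype=jnp.int32)
print(not jnp.all(prompt_lens == padded))  # True: request 1 is mid-prefill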

tpu_inference/runner/compilation_manager.py

Lines changed: 3 additions & 5 deletions

@@ -121,20 +121,18 @@ def _precompile_pooling(self) -> None:
             if num_reqs == 0 or num_reqs > num_tokens:
                 continue

+            # can we just use (one array here)
             prompt_lens = self._create_dummy_tensor(num_reqs, dtype=jnp.int32)
             first_token_indices = self._create_dummy_tensor(num_reqs, dtype=jnp.int32)
             last_token_indices = self._create_dummy_tensor(num_reqs, dtype=jnp.int32)
-            normalize = self._create_dummy_tensor(num_reqs, dtype=jnp.int32)
+            num_scheduled_tokens = self._create_dummy_tensor(num_reqs, dtype=jnp.int32)


             pooling_metadata = TPUSupportedPoolingMetadata(
                 prompt_lens=prompt_lens,
                 first_token_indices=first_token_indices,
                 last_token_indices=last_token_indices,
-                normalize=normalize,
-                num_reqs=num_reqs,
-                padded_num_reqs=num_reqs,
-                task="embed",
+                num_scheduled_tokens=num_scheduled_tokens,
             )

             self._run_compilation(
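For context, the precompile path only needs arrays of the right shape and dtype so each num_reqs bucket is traced once. A rough stand-in for what it now builds is below; it assumes _create_dummy_tensor is essentially a zero-filled int32 vector, and uses a plain dict in place of the real dataclass, so it is a sketch rather than the actual helper.

import jax.numpy as jnp

def dummy_pooling_metadata(num_reqs: int) -> dict:
    # Four identically shaped int32 vectors -- which is what the in-diff
    # "can we just use (one array here)" comment is hinting at.
    dummy = lambda: jnp.zeros((num_reqs,), dtype=jnp.int32)
    return {
        "prompt_lens": dummy(),
        "first_token_indices": dummy(),
        "last_token_indices": dummy(),
        "num_scheduled_tokens": dummy(),
    }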

tpu_inference/runner/tpu_runner.py

Lines changed: 1 addition & 3 deletions

@@ -505,7 +505,6 @@ def load_model(self):
        if self.is_pooling_model:
            self.pooler = self.model.pooler

-        print(f"DEBUGPRINT[96]: tpu_jax_runner.py:396: self.is_pooling_model={self.is_pooling_model}")
        self.precompile_vision_encoder_fn = multimodal_fns.get(
            "precompile_vision_encoder_fn", None)
        self.get_multimodal_embeddings_fn = multimodal_fns.get(
@@ -530,7 +529,7 @@ def load_model(self):

    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        if self.is_pooling_model:
-            return ("embed", )
+            return self.pooler.get_supported_tasks()
        return ("generate", )

    def get_kv_cache_spec(self):
@@ -780,7 +779,6 @@ def _execute_model(
            seq_lens_cpu = self.seq_lens_cpu[:num_reqs]


-
            pooler_output = []
            for raw_output, seq_len, prompt_len in zip(raw_pooler_output, seq_lens_cpu, prompt_lens):
                output = raw_output if seq_len == prompt_len else None