Skip to content

Commit cc78a3a

Browse files
committed
add support torchax embedding
1 parent 2baaaaa commit cc78a3a

File tree

3 files changed

+44
-20
lines changed

3 files changed

+44
-20
lines changed

tpu_inference/models/jax/adapters.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import typing as tp
22

3+
import torch
34
import jax
45
from flax import nnx
6+
from flax.typing import PRNGKey
57
from jax.sharding import Mesh
68

79
from tpu_inference.layers.jax.pool.pooler import Pooler
@@ -84,3 +86,31 @@ def _init_pooler(self, vllm_config: VllmConfig) -> None:
8486
"ForEmbedding",
8587
)
8688
return ModelForEmbedding # type: ignore[return-value]
89+
90+
91+
92+
def init_pooler_from_vllm_model(
    vllm_model: torch.nn.Module,
    vllm_config: VllmConfig,
    rng_key: PRNGKey,
    mesh: Mesh,
):
    """Create the pooler that corresponds to a vLLM (torch) model class.

    The decision is made purely from the class name of ``vllm_model``:

    * if the name contains any entry of ``_GENERATE_SUFFIXES``, no pooler is
      needed and ``None`` is returned;
    * if the name contains ``"ForEmbedding"``, a placeholder module is wrapped
      with ``as_embedding_model`` and the ``pooler`` attribute of the
      resulting instance is returned;
    * anything else is rejected.

    Args:
        vllm_model: The loaded vLLM model; only its class name is inspected.
        vllm_config: Forwarded to the embedding wrapper's constructor and to
            its ``_init_pooler`` method.
        rng_key: Forwarded to the embedding wrapper's constructor.
        mesh: Forwarded to the embedding wrapper's constructor.

    Returns:
        The ``pooler`` attribute of the embedding wrapper, or ``None`` for
        models matching a generate suffix.

    Raises:
        NotImplementedError: If the class name matches neither a generate
            suffix nor ``"ForEmbedding"``.
    """

    # Stand-in base class: accepts the constructor arguments the embedding
    # adapter passes but builds nothing, so only pooler state gets created.
    class DummyModule:

        def __init__(self, vllm_config, rng_key, mesh):
            pass

    model_cls_name = vllm_model.__class__.__name__

    # The generate-suffix check takes precedence over the embedding check.
    if any(suffix in model_cls_name for suffix in _GENERATE_SUFFIXES):
        return None

    if "ForEmbedding" not in model_cls_name:
        raise NotImplementedError(
            f"Pooling initialization for {vllm_model.__class__.__name__} is not implemented."
        )

    wrapper_cls = as_embedding_model(DummyModule)
    wrapper = wrapper_cls(
        vllm_config=vllm_config,
        rng_key=rng_key,
        mesh=mesh,
    )
    # NOTE(review): if the adapter's __init__ already invokes _init_pooler,
    # this is a second initialization -- confirm against as_embedding_model.
    wrapper._init_pooler(vllm_config)
    return wrapper.pooler
116+

tpu_inference/models/vllm/vllm_model_wrapper.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@
3030
from tpu_inference.models.vllm.vllm_model_wrapper_context import (
3131
get_vllm_model_wrapper_context, set_vllm_model_wrapper_context)
3232
from tpu_inference.runner.lora_utils import replace_lora_metadata
33+
from tpu_inference.layers.jax.pool.pooler import Pooler
34+
from tpu_inference.models.jax.adapters import init_pooler_from_vllm_model
3335

3436
logger = init_logger(__name__)
3537

@@ -72,6 +74,7 @@ class VllmModelWrapper:
7274
rng: PRNGKey
7375
mesh: Mesh
7476
model: _VllmRunner
77+
pooler: Pooler
7578

7679
def __init__(self, vllm_config: VllmConfig, rng: PRNGKey, mesh: Mesh):
7780
self.vllm_config = vllm_config
@@ -137,6 +140,10 @@ def load_weights(self):
137140
self.model = _VllmRunner(vllm_model)
138141
params_and_buffers = shard_model_to_tpu(self.model, self.mesh)
139142

143+
144+
# TODO: need to separate params_and_buffers for the pooler (some poolers are not stateless)
145+
self.pooler = init_pooler_from_vllm_model(vllm_model, self.vllm_config, self.rng, self.mesh)
146+
140147
# Returning to the jax land, so we need to wrap it into a JaxValue.
141148
return jax_view(params_and_buffers), lora_manager
142149

tpu_inference/runner/compilation_manager.py

Lines changed: 7 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
)
2222
from tpu_inference.logger import init_logger
2323
from tpu_inference.utils import device_array
24+
from torchax.ops.mappings import t2j_dtype
2425

2526
if TYPE_CHECKING:
2627
from tpu_inference.runner.tpu_runner import TPUModelRunner
@@ -114,31 +115,17 @@ def _precompile_pooling(self) -> None:
114115

115116
for num_tokens in self.runner.num_tokens_paddings:
116117
hidden_states = self._create_dummy_tensor(
117-
(num_tokens, hidden_size), dtype, sharding=hidden_sharding)
118+
(num_tokens, hidden_size), t2j_dtype(dtype), sharding=hidden_sharding)
118119

119120
for num_reqs in self.runner.num_reqs_paddings:
120121
if num_reqs == 0 or num_reqs > num_tokens:
121122
continue
122123

123-
prompt_lens = np.ones(num_reqs, dtype=np.int32)
124-
first_token_indices = np.arange(num_reqs, dtype=np.int32)
125-
last_token_indices = first_token_indices.copy()
126-
normalize = np.ones(num_reqs, dtype=np.int8)
127-
128-
(
129-
prompt_lens,
130-
normalize,
131-
first_token_indices,
132-
last_token_indices,
133-
) = device_array(
134-
self.runner.mesh,
135-
(
136-
prompt_lens,
137-
normalize,
138-
first_token_indices,
139-
last_token_indices,
140-
),
141-
)
124+
prompt_lens = self._create_dummy_tensor(num_reqs, dtype = jnp.int32)
125+
first_token_indices = self._create_dummy_tensor(num_reqs, dtype = jnp.int32)
126+
last_token_indices = self._create_dummy_tensor(num_reqs, dtype = jnp.int32)
127+
normalize = self._create_dummy_tensor(num_reqs, dtype = jnp.int32)
128+
142129

143130
pooling_metadata = TPUSupportedPoolingMetadata(
144131
prompt_lens=prompt_lens,

0 commit comments

Comments
 (0)