Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion tpu_inference/models/jax/utils/weight_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import jax
import jax.numpy as jnp
import torch
from flax import nnx
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P
Expand Down Expand Up @@ -114,6 +115,13 @@ def get_model_weights_files(
return weights_files


def _convert_torch_type_to_jax_type(torch_dtype: torch.dtype):
dtype_map = {
torch.float8_e4m3fn: jnp.float8_e4m3fn,
}
return dtype_map.get(torch_dtype)


def model_weights_single_file_generator(
weights_file: str,
framework: str,
Expand All @@ -130,7 +138,20 @@ def model_weights_single_file_generator(
if filter_regex is not None and not re.match(
filter_regex, name):
continue
weight_tensor = f.get_tensor(name)
try:
weight_tensor = f.get_tensor(name)
# Flax numpy does not support float8 dtype, so we have to load it using torch in that case.
except Exception as e:
logger.warning(
f"Failed to load tensor '{name}' with framework '{framework}' due to error {e}. Trying load with framework=pt again."
)
with safe_open(weights_file, framework='pt') as pt_f:
weight_tensor = pt_f.get_tensor(name)
# convert the torch tensor to jax.
jax_dtype = _convert_torch_type_to_jax_type(
weight_tensor.dtype)
weight_tensor = weight_tensor.to(
torch.float32).numpy().astype(jax_dtype)
yield name, weight_tensor


Expand Down Expand Up @@ -294,6 +315,9 @@ def _load_hf_weights_on_thread(vllm_config,
)
hf_weight = hf_weight.astype(model_config.dtype)

# TODO: Handle weight_scale
if hf_key.endswith(".weight_scale"):
continue
if hf_key.endswith(".weight"):
hf_key = hf_key.removesuffix(".weight")

Expand Down