diff --git a/tpu_inference/models/jax/utils/weight_utils.py b/tpu_inference/models/jax/utils/weight_utils.py
index 4209af473..292137269 100644
--- a/tpu_inference/models/jax/utils/weight_utils.py
+++ b/tpu_inference/models/jax/utils/weight_utils.py
@@ -12,6 +12,7 @@
 
 import jax
 import jax.numpy as jnp
+import torch
 from flax import nnx
 from jax.sharding import Mesh, NamedSharding
 from jax.sharding import PartitionSpec as P
@@ -114,6 +115,23 @@ def get_model_weights_files(
     return weights_files
 
 
+def _convert_torch_type_to_jax_type(torch_dtype: torch.dtype):
+    """Return the JAX dtype that corresponds to `torch_dtype`.
+
+    Raises:
+        ValueError: If `torch_dtype` has no entry in the mapping.
+    """
+    dtype_map = {
+        torch.float8_e4m3fn: jnp.float8_e4m3fn,
+    }
+    jax_dtype = dtype_map.get(torch_dtype)
+    if jax_dtype is None:
+        # numpy's astype(None) silently means float64, so fail loudly.
+        raise ValueError(
+            f"No JAX dtype mapping for torch dtype {torch_dtype}.")
+    return jax_dtype
+
+
 def model_weights_single_file_generator(
     weights_file: str,
     framework: str,
@@ -130,7 +148,21 @@ def model_weights_single_file_generator(
                 if filter_regex is not None and not re.match(
                         filter_regex, name):
                     continue
-                weight_tensor = f.get_tensor(name)
+                try:
+                    weight_tensor = f.get_tensor(name)
+                except Exception as e:
+                    # Flax numpy does not support float8 dtype, so we have to
+                    # load the tensor using torch in that case.
+                    logger.warning(
+                        f"Failed to load tensor '{name}' with framework '{framework}' due to error {e}. Trying load with framework=pt again."
+                    )
+                    with safe_open(weights_file, framework='pt') as pt_f:
+                        weight_tensor = pt_f.get_tensor(name)
+                    # Convert the torch tensor to a numpy array of JAX dtype.
+                    jax_dtype = _convert_torch_type_to_jax_type(
+                        weight_tensor.dtype)
+                    weight_tensor = weight_tensor.to(
+                        torch.float32).numpy().astype(jax_dtype)
                 yield name, weight_tensor
 
 
@@ -294,6 +326,9 @@ def _load_hf_weights_on_thread(vllm_config,
             )
             hf_weight = hf_weight.astype(model_config.dtype)
 
+        # TODO: Handle weight_scale
+        if hf_key.endswith(".weight_scale"):
+            continue
         if hf_key.endswith(".weight"):
             hf_key = hf_key.removesuffix(".weight")
 