diff --git a/tpu_inference/platforms/tpu_platform.py b/tpu_inference/platforms/tpu_platform.py
index b3a4a7de3..fe93e58fb 100644
--- a/tpu_inference/platforms/tpu_platform.py
+++ b/tpu_inference/platforms/tpu_platform.py
@@ -4,7 +4,6 @@
 import jax.numpy as jnp
 import vllm.envs as vllm_envs
-from torchax.ops.mappings import j2t_dtype
 from tpu_info import device
 from vllm.inputs import ProcessorInputs, PromptType
 from vllm.platforms.interface import Platform, PlatformEnum
@@ -13,6 +12,7 @@
 from tpu_inference import envs
 from tpu_inference.layers.common.sharding import ShardingConfigManager
 from tpu_inference.logger import init_logger
+from tpu_inference.utils import to_jax_dtype, to_torch_dtype

 if TYPE_CHECKING:
     from vllm.attention.backends.registry import _Backend
@@ -150,18 +150,11 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         # For mm model preprocessors, it may need the output dtype to be torch.
         # In order to avoid a PR to vLLM, we postpone the dtype checking during tpu_worker initialization
         if not vllm_config.scheduler_config.is_multimodal_model or impl == "vllm":
-            if not isinstance(vllm_config.model_config.dtype, str):
-                logger.warning(
-                    "The model dtype is not properly set for JAX backend. "
-                    "Overwriting it to jnp.bfloat16")
-                vllm_config.model_config.dtype = jnp.bfloat16
+            dtype = vllm_config.model_config.dtype
+            if impl == "vllm":
+                vllm_config.model_config.dtype = to_torch_dtype(dtype)
             else:
-                vllm_config.model_config.dtype = _DTYPE.get(
-                    vllm_config.model_config.dtype, jnp.bfloat16)
-
-            if impl == "vllm":
-                vllm_config.model_config.dtype = j2t_dtype(
-                    vllm_config.model_config.dtype.dtype)
+                vllm_config.model_config.dtype = to_jax_dtype(dtype)

         # TODO(cuiq): remove this dependency.
         from vllm.v1.attention.backends.pallas import PallasAttentionBackend
diff --git a/tpu_inference/runner/tpu_runner.py b/tpu_inference/runner/tpu_runner.py
index 3841e7460..6ed197ede 100644
--- a/tpu_inference/runner/tpu_runner.py
+++ b/tpu_inference/runner/tpu_runner.py
@@ -10,12 +10,11 @@
 import jax.numpy as jnp
 import jaxtyping
 import numpy as np
-import torch
 import vllm.envs as envs
 from flax import nnx
 from jax.experimental import mesh_utils
 from jax.sharding import NamedSharding, PartitionSpec
-from torchax.ops.mappings import j2t, j2t_dtype
+from torchax.ops.mappings import j2t
 from vllm.config import VllmConfig
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           has_kv_transfer_group)
@@ -64,7 +63,7 @@
     StructuredDecodingManager
 from tpu_inference.spec_decode.jax.eagle3 import Eagle3Proposer
 from tpu_inference.utils import (device_array, make_optimized_mesh,
-                                 time_function)
+                                 time_function, to_torch_dtype)

 logger = init_logger(__name__)

@@ -78,17 +77,6 @@
     request_distribution=[0, 0, 0],
 )

-TPU_STR_DTYPE_TO_TORCH_DTYPE = {
-    "half": torch.half,
-    "bfloat16": torch.bfloat16,
-    "float": torch.float,
-    "fp8": torch.float8_e4m3fn,
-    "fp8_e4m3": torch.float8_e4m3fn,
-    "fp8_e5m2": torch.float8_e5m2,
-    "int8": torch.int8,
-    "uint8": torch.uint8,
-}
-

 class AsyncTPUModelRunnerOutput(AsyncModelRunnerOutput):
     """Holds asynchronous model output specifically from a TPU runner.
@@ -250,22 +238,10 @@ def __init__(
             self.uses_mrope, self.model_config)
         self.lora_utils = LoraUtils(self)

-        cache_config = self.cache_config
-        if cache_config.cache_dtype == "auto":
-            model_dtype = self.dtype
-            if isinstance(model_dtype, str):
-                self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
-            elif isinstance(getattr(model_dtype, 'dtype', None), jnp.dtype):
-                self.kv_cache_dtype = j2t_dtype(model_dtype.dtype)
-            elif isinstance(model_dtype, torch.dtype):
-                self.kv_cache_dtype = model_dtype
-            else:
-                raise ValueError(
-                    "KV cache is unsupported for model_dtype of %s",
-                    model_dtype)
-        else:
-            self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[
-                cache_config.cache_dtype]
+        cache_dtype = self.cache_config.cache_dtype
+        if cache_dtype == "auto":
+            cache_dtype = self.dtype
+        self.kv_cache_dtype = to_torch_dtype(cache_dtype)

         self._pre_async_results: AsyncPreResults | None = None
         self._substitute_placeholder_token_fn = _substitute_placeholder_token
diff --git a/tpu_inference/utils.py b/tpu_inference/utils.py
index ca3d693da..a1dd7b62d 100644
--- a/tpu_inference/utils.py
+++ b/tpu_inference/utils.py
@@ -8,11 +8,14 @@
 import jax
 import jax.numpy as jnp
 import numpy as np
+import torch
 from jax._src import dtypes
 from jax._src import mesh as mesh_lib
 from jax._src import xla_bridge as xb
 from jax._src.lib import xla_client as xc
+from jax._src.numpy.scalar_types import _ScalarMeta
 from jax.sharding import Mesh, NamedSharding, PartitionSpec
+from torchax.ops.mappings import j2t_dtype, t2j_dtype
 from vllm import envs as vllm_envs
 from vllm import utils

@@ -26,13 +29,23 @@
 # This is used to translate from a string name for a dtype
 # to formal jax.numpy DType. One use case for this is
 # converting the `--kv_cache_dtype` flag to a dtype.
-TPU_STR_DTYPE_TO_JAX_DTYPE = {
-    "bfloat16": jnp.bfloat16,
-    "fp8": jnp.float8_e4m3fn,
-    "fp8_e4m3": jnp.float8_e4m3,
-    "fp8_e5m2": jnp.float8_e5m2,
-    "int8": jnp.int8,
-}
+
+
+def to_jax_dtype(dtype: str | jnp.dtype | torch.dtype):
+    if isinstance(dtype, str):
+        return jnp.dtype(dtype)
+    elif isinstance(dtype, torch.dtype):
+        return t2j_dtype(dtype)
+    elif isinstance(dtype, jnp.dtype):
+        return dtype
+    elif isinstance(dtype, _ScalarMeta):
+        return dtype.dtype
+
+
+def to_torch_dtype(dtype: str | jnp.dtype | torch.dtype):
+    dtype = to_jax_dtype(dtype)
+    return j2t_dtype(dtype)
+

 _megacore = False
 logger = init_logger(__name__)
@@ -295,8 +308,8 @@ def get_jax_dtype_from_str_dtype(str_dtype: str) -> jnp.dtype:
     Returns:
         jnp.dtype: The JAX dtype.
     """
-    str_dtype = str_dtype.lower().strip()
-    return TPU_STR_DTYPE_TO_JAX_DTYPE.get(str_dtype)
+    # TODO(kyuyeunk): Replace all reference of this function into TpuDtype.
+    return to_jax_dtype(str_dtype)


 def time_function(func):
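Example (not part of the patch): a minimal sketch of how the new dtype helpers behave once the diff above is applied, assuming torch and torchax are installed; the variable names and values below are illustrative stand-ins, not code from the PR.

# Illustrative sketch: exercising to_jax_dtype / to_torch_dtype from tpu_inference/utils.py.
import jax.numpy as jnp
import torch

from tpu_inference.utils import to_jax_dtype, to_torch_dtype

# Strings, torch dtypes, and jnp scalar types all normalize to a JAX dtype.
print(to_jax_dtype("bfloat16"))      # string path (jnp.dtype lookup)
print(to_jax_dtype(torch.bfloat16))  # torch path (torchax t2j_dtype)
print(to_jax_dtype(jnp.bfloat16))    # jnp scalar-type path (_ScalarMeta)

# to_torch_dtype routes through to_jax_dtype, then maps back via j2t_dtype.
assert to_torch_dtype(jnp.bfloat16) == torch.bfloat16

# Roughly how tpu_runner.py now resolves the KV-cache dtype: "auto" falls back
# to the model dtype, everything else is converted directly.
cache_dtype = "auto"        # stand-in for cache_config.cache_dtype
model_dtype = jnp.bfloat16  # stand-in for self.dtype
kv_cache_dtype = to_torch_dtype(model_dtype if cache_dtype == "auto" else cache_dtype)
assert kv_cache_dtype == torch.bfloat16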