tpu_inference/platforms/tpu_platform.py (17 changes: 5 additions & 12 deletions)
@@ -4,7 +4,6 @@
 
 import jax.numpy as jnp
 import vllm.envs as vllm_envs
-from torchax.ops.mappings import j2t_dtype
 from tpu_info import device
 from vllm.inputs import ProcessorInputs, PromptType
 from vllm.platforms.interface import Platform, PlatformEnum
@@ -13,6 +12,7 @@
 from tpu_inference import envs
 from tpu_inference.layers.common.sharding import ShardingConfigManager
 from tpu_inference.logger import init_logger
+from tpu_inference.utils import to_jax_dtype, to_torch_dtype
 
 if TYPE_CHECKING:
     from vllm.attention.backends.registry import _Backend
@@ -150,18 +150,11 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         # Multimodal model preprocessors may need the output dtype to be a torch dtype.
         # To avoid a PR to vLLM, we postpone the dtype check until tpu_worker initialization.
         if not vllm_config.scheduler_config.is_multimodal_model or impl == "vllm":
-            if not isinstance(vllm_config.model_config.dtype, str):
-                logger.warning(
-                    "The model dtype is not properly set for JAX backend. "
-                    "Overwriting it to jnp.bfloat16")
-                vllm_config.model_config.dtype = jnp.bfloat16
-            else:
-                vllm_config.model_config.dtype = _DTYPE.get(
-                    vllm_config.model_config.dtype, jnp.bfloat16)
-
-            if impl == "vllm":
-                vllm_config.model_config.dtype = j2t_dtype(
-                    vllm_config.model_config.dtype.dtype)
+            dtype = vllm_config.model_config.dtype
+            if impl == "vllm":
+                vllm_config.model_config.dtype = to_torch_dtype(dtype)
+            else:
+                vllm_config.model_config.dtype = to_jax_dtype(dtype)
 
         # TODO(cuiq): remove this dependency.
         from vllm.v1.attention.backends.pallas import PallasAttentionBackend
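The net effect of this hunk: instead of warning and force-overwriting the model dtype to bfloat16, the platform now funnels whatever dtype vLLM parsed through a single helper pair. A minimal sketch of that behavior, assuming the to_jax_dtype/to_torch_dtype helpers added by this PR; FakeModelConfig and the "flax_nnx" impl name are illustrative stand-ins, not part of the diff:

import dataclasses

from tpu_inference.utils import to_jax_dtype, to_torch_dtype


@dataclasses.dataclass
class FakeModelConfig:
    # Hypothetical stand-in for vllm.config.ModelConfig; only the one
    # field that check_and_update_config rewrites is modeled.
    dtype: object


def normalize_model_dtype(model_config: FakeModelConfig, impl: str) -> None:
    # Mirrors the new branch above: torch dtypes for the "vllm" impl,
    # JAX dtypes for everything else.
    dtype = model_config.dtype
    if impl == "vllm":
        model_config.dtype = to_torch_dtype(dtype)
    else:
        model_config.dtype = to_jax_dtype(dtype)


cfg = FakeModelConfig(dtype="bfloat16")
normalize_model_dtype(cfg, impl="flax_nnx")
print(cfg.dtype)  # expected: the bfloat16 jnp.dtype, with no warning and no overwrite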
tpu_inference/runner/tpu_runner.py (36 changes: 6 additions & 30 deletions)
@@ -10,12 +10,11 @@
 import jax.numpy as jnp
 import jaxtyping
 import numpy as np
-import torch
 import vllm.envs as envs
 from flax import nnx
 from jax.experimental import mesh_utils
 from jax.sharding import NamedSharding, PartitionSpec
-from torchax.ops.mappings import j2t, j2t_dtype
+from torchax.ops.mappings import j2t
 from vllm.config import VllmConfig
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           has_kv_transfer_group)
@@ -64,7 +63,7 @@
     StructuredDecodingManager
 from tpu_inference.spec_decode.jax.eagle3 import Eagle3Proposer
 from tpu_inference.utils import (device_array, make_optimized_mesh,
-                                 time_function)
+                                 time_function, to_torch_dtype)
 
 logger = init_logger(__name__)
 
@@ -78,17 +77,6 @@
     request_distribution=[0, 0, 0],
 )
 
-TPU_STR_DTYPE_TO_TORCH_DTYPE = {
-    "half": torch.half,
-    "bfloat16": torch.bfloat16,
-    "float": torch.float,
-    "fp8": torch.float8_e4m3fn,
-    "fp8_e4m3": torch.float8_e4m3fn,
-    "fp8_e5m2": torch.float8_e5m2,
-    "int8": torch.int8,
-    "uint8": torch.uint8,
-}
-
 
 class AsyncTPUModelRunnerOutput(AsyncModelRunnerOutput):
     """Holds asynchronous model output specifically from a TPU runner.
@@ -250,22 +238,10 @@ def __init__(
             self.uses_mrope, self.model_config)
         self.lora_utils = LoraUtils(self)
 
-        cache_config = self.cache_config
-        if cache_config.cache_dtype == "auto":
-            model_dtype = self.dtype
-            if isinstance(model_dtype, str):
-                self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
-            elif isinstance(getattr(model_dtype, 'dtype', None), jnp.dtype):
-                self.kv_cache_dtype = j2t_dtype(model_dtype.dtype)
-            elif isinstance(model_dtype, torch.dtype):
-                self.kv_cache_dtype = model_dtype
-            else:
-                raise ValueError(
-                    "KV cache is unsupported for model_dtype of %s",
-                    model_dtype)
-        else:
-            self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[
-                cache_config.cache_dtype]
+        cache_dtype = self.cache_config.cache_dtype
+        if cache_dtype == "auto":
+            cache_dtype = self.dtype
+        self.kv_cache_dtype = to_torch_dtype(cache_dtype)
 
         self._pre_async_results: AsyncPreResults | None = None
         self._substitute_placeholder_token_fn = _substitute_placeholder_token
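With the module-level dtype table gone, kv-cache dtype resolution reduces to one fallback plus one conversion. A rough sketch of the resulting behavior, assuming this PR's to_torch_dtype; resolve_kv_cache_dtype is my own wrapper for illustration, not a function in the runner:

import jax.numpy as jnp

from tpu_inference.utils import to_torch_dtype


def resolve_kv_cache_dtype(cache_dtype, model_dtype):
    # "auto" defers to the model dtype; any explicit --kv-cache-dtype value
    # is handed to to_torch_dtype, replacing the deleted if/elif ladder.
    if cache_dtype == "auto":
        cache_dtype = model_dtype
    return to_torch_dtype(cache_dtype)


print(resolve_kv_cache_dtype("auto", jnp.bfloat16))     # expected: torch.bfloat16
print(resolve_kv_cache_dtype("bfloat16", jnp.float32))  # explicit flag wins: torch.bfloat16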
tpu_inference/utils.py (31 changes: 22 additions & 9 deletions)
@@ -8,11 +8,14 @@
 import jax
 import jax.numpy as jnp
 import numpy as np
+import torch
 from jax._src import dtypes
 from jax._src import mesh as mesh_lib
 from jax._src import xla_bridge as xb
 from jax._src.lib import xla_client as xc
+from jax._src.numpy.scalar_types import _ScalarMeta
 from jax.sharding import Mesh, NamedSharding, PartitionSpec
+from torchax.ops.mappings import j2t_dtype, t2j_dtype
 from vllm import envs as vllm_envs
 from vllm import utils

@@ -26,13 +29,23 @@
 # This is used to translate from a string name for a dtype
 # to formal jax.numpy DType. One use case for this is
 # converting the `--kv_cache_dtype` flag to a dtype.
-TPU_STR_DTYPE_TO_JAX_DTYPE = {
-    "bfloat16": jnp.bfloat16,
-    "fp8": jnp.float8_e4m3fn,
-    "fp8_e4m3": jnp.float8_e4m3,
-    "fp8_e5m2": jnp.float8_e5m2,
-    "int8": jnp.int8,
-}
+
+
+def to_jax_dtype(dtype: str | jnp.dtype | torch.dtype):
+    if isinstance(dtype, str):
+        return jnp.dtype(dtype)
+    elif isinstance(dtype, torch.dtype):
+        return t2j_dtype(dtype)
+    elif isinstance(dtype, jnp.dtype):
+        return dtype
+    elif isinstance(dtype, _ScalarMeta):
+        return dtype.dtype
+
+
+def to_torch_dtype(dtype: str | jnp.dtype | torch.dtype):
+    dtype = to_jax_dtype(dtype)
+    return j2t_dtype(dtype)
 
 
 _megacore = False
 logger = init_logger(__name__)
@@ -295,8 +308,8 @@ def get_jax_dtype_from_str_dtype(str_dtype: str) -> jnp.dtype:
     Returns:
         jnp.dtype: The JAX dtype.
     """
-    str_dtype = str_dtype.lower().strip()
-    return TPU_STR_DTYPE_TO_JAX_DTYPE.get(str_dtype)
+    # TODO(kyuyeunk): Replace all references to this function with TpuDtype.
+    return to_jax_dtype(str_dtype)
 
 
 def time_function(func):
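For reference, a quick sketch of what the new helpers accept, with expected results in comments; this assumes torchax's t2j_dtype/j2t_dtype cover the usual float types. One caveat visible in the code above: to_jax_dtype has no else branch, so anything outside the four handled cases falls through and returns None rather than raising, unlike the runner's deleted ValueError path:

import jax.numpy as jnp
import torch

from tpu_inference.utils import to_jax_dtype, to_torch_dtype

print(to_jax_dtype("bfloat16"))     # str branch: jnp.dtype("bfloat16")
print(to_jax_dtype(torch.float16))  # torch branch, via t2j_dtype: float16
print(to_jax_dtype(jnp.bfloat16))   # _ScalarMeta branch: unwraps to .dtype
print(to_torch_dtype("float32"))    # str -> jax -> torch: torch.float32
print(to_jax_dtype(3.14))           # unhandled type: falls through, returns None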