48 changes: 32 additions & 16 deletions axlearn/common/array_serialization.py
@@ -36,6 +36,7 @@
from jax._src import array, typing
from jax._src.layout import Layout
from jax.experimental.array_serialization import serialization
from packaging import version

from axlearn.common.utils import Tensor

@@ -75,6 +76,7 @@ def shard_coordinate(self):

# Tuple (and thus hashable) representation of a slice object (start, end, step).
_SliceTuple = tuple[Optional[int], Optional[int], Optional[int]]
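# Parsed JAX version, used below to select the serialization helpers that match the installed release.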
JAX_VERSION = version.parse(jax.__version__)


def _slices_to_tuple(slices: list[slice]) -> tuple[_SliceTuple, ...]:
@@ -306,10 +308,13 @@ async def _async_serialize(
and arr_inp.is_fully_addressable
)
# pylint: disable=protected-access
spec_has_metadata = {
"0.6.2": lambda: serialization.ts_impl._spec_has_metadata,
"0.5.3": lambda: serialization._spec_has_metadata,
}[jax.__version__]()
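# JAX >= 0.6.2 exposes these private helpers under `serialization.ts_impl`; older supported
# releases (>= 0.5.3) expose them at the top level of `serialization`.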
if JAX_VERSION >= version.parse("0.6.2"):
spec_has_metadata = serialization.ts_impl._spec_has_metadata
elif JAX_VERSION >= version.parse("0.5.3"):
spec_has_metadata = serialization._spec_has_metadata
else:
raise ValueError(f"Unsupported JAX version for spec_has_metadata: {jax.__version__}")

if not spec_has_metadata(tensorstore_spec):
# pylint: disable-next=protected-access
tensorstore_spec["metadata"] = serialization._get_metadata(arr_inp)
@@ -486,10 +491,14 @@ async def _async_deserialize(
async def cb(index: array.Index, device: jax.Device):
requested_domain = ts.IndexTransform(input_shape=shape)[index].domain
restricted_domain = t.domain.intersect(requested_domain)
estimate_read_memory_footprint = {
"0.6.2": lambda: serialization.ts_impl.estimate_read_memory_footprint,
"0.5.3": lambda: serialization.estimate_read_memory_footprint,
}[jax.__version__]()
if JAX_VERSION >= version.parse("0.6.2"):
estimate_read_memory_footprint = serialization.ts_impl.estimate_read_memory_footprint
elif JAX_VERSION >= version.parse("0.5.3"):
estimate_read_memory_footprint = serialization.estimate_read_memory_footprint
else:
raise ValueError(
f"Unsupported JAX version: {JAX_VERSION}. Version must be 0.5.3 or newer"
)
requested_bytes = estimate_read_memory_footprint(t, restricted_domain)
# Limit the bytes read for every shard.
await byte_limiter.wait_for_bytes(requested_bytes)
@@ -567,10 +576,13 @@ async def cb(index: array.Index, device: jax.Device):
return result

# pylint: disable=protected-access
create_async_array_from_callback = {
"0.6.2": lambda: serialization.ts_impl._create_async_array_from_callback,
"0.5.3": lambda: serialization.create_async_array_from_callback,
}[jax.__version__]()
if JAX_VERSION >= version.parse("0.6.2"):
create_async_array_from_callback = serialization.ts_impl._create_async_array_from_callback
elif JAX_VERSION >= version.parse("0.5.3"):
create_async_array_from_callback = serialization.create_async_array_from_callback
else:
raise ValueError(f"Unsupported JAX version: {JAX_VERSION}. Version must be 0.5.3 or newer.")

return await create_async_array_from_callback(shape, in_sharding, cb)


@@ -652,10 +664,14 @@ def serialize(

commit_futures = [[] for _ in range(len(tensorstore_specs))]

async_serialize = {
"0.6.2": lambda: serialization.ts_impl.async_serialize,
"0.5.3": lambda: serialization.async_serialize,
}[jax.__version__]()
if JAX_VERSION >= version.parse("0.6.2"):
async_serialize = serialization.ts_impl.async_serialize
elif JAX_VERSION >= version.parse("0.5.3"):
async_serialize = serialization.async_serialize
else:
raise ValueError(
f"Unsupported JAX version: {JAX_VERSION}. Version must be 0.5.3 or newer."
)

# pylint: disable-next=redefined-outer-name
async def _run_serializer():
10 changes: 8 additions & 2 deletions axlearn/common/flash_attention/gpu_attention.py
@@ -44,7 +44,6 @@
)
from jax.ad_checkpoint import checkpoint_name
from jax.experimental import pallas as pl
from jax.experimental.pallas.triton import TritonCompilerParams

from axlearn.common.attention_bias import (
NEG_INF,
@@ -69,7 +68,14 @@
from axlearn.common.kv_cache.base_kv_cache import BaseKVCache
from axlearn.common.kv_cache.kv_cache import KVCache
from axlearn.common.layers import get_dropout_mask
from axlearn.common.utils import Nested, Tensor
from axlearn.common.utils import _JAX_MEMORY_SPACE_SUPPORT, Nested, Tensor

# pylint: disable=ungrouped-imports
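# Newer JAX releases rename the Pallas Triton `TritonCompilerParams` to `CompilerParams`;
# `_JAX_MEMORY_SPACE_SUPPORT` is used here as a proxy for that newer API surface.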
if _JAX_MEMORY_SPACE_SUPPORT:
from jax.experimental.pallas.triton import CompilerParams as TritonCompilerParams
else:
from jax.experimental.pallas.triton import TritonCompilerParams
# pylint: enable=ungrouped-imports


def _segment_mask(
10 changes: 8 additions & 2 deletions axlearn/common/flash_attention/gpu_decoding.py
@@ -49,7 +49,6 @@
from jax import lax
from jax._src.cudnn.fused_attention_stablehlo import check_compute_capability
from jax.experimental import pallas as pl
from jax.experimental.pallas.triton import TritonCompilerParams

from axlearn.common.attention_bias import (
NEG_INF,
@@ -61,7 +60,14 @@
from axlearn.common.flash_attention.common import BaseSingleStepDecoding, get_gpu_dot_precision
from axlearn.common.kv_cache.base_kv_cache import BaseKVCache
from axlearn.common.kv_cache.kv_cache import KVCache
from axlearn.common.utils import Nested, Tensor
from axlearn.common.utils import _JAX_MEMORY_SPACE_SUPPORT, Nested, Tensor

# pylint: disable=ungrouped-imports
if _JAX_MEMORY_SPACE_SUPPORT:
from jax.experimental.pallas.triton import CompilerParams as TritonCompilerParams
else:
from jax.experimental.pallas.triton import TritonCompilerParams
# pylint: enable=ungrouped-imports


# Note: split_k_seq_len must be a multiple of block_k.
12 changes: 10 additions & 2 deletions axlearn/common/flash_attention/gpu_paged_attention.py
@@ -19,7 +19,6 @@
import jax.numpy as jnp
from jax import lax
from jax.experimental import pallas as pl
from jax.experimental.pallas.triton import TritonCompilerParams

from axlearn.common.attention_bias import (
NEG_INF,
@@ -31,7 +30,16 @@
from axlearn.common.flash_attention.common import BasePagedAttention, get_gpu_dot_precision
from axlearn.common.flash_attention.gpu_decoding import _get_sm_count as get_sm_count
from axlearn.common.kv_cache.base_kv_cache import BaseKVCache
from axlearn.common.utils import Nested, Tensor
from axlearn.common.utils import _JAX_MEMORY_SPACE_SUPPORT, Nested, Tensor

# pylint: disable=ungrouped-imports
if _JAX_MEMORY_SPACE_SUPPORT:
from jax.experimental.pallas.triton import CompilerParams as TritonCompilerParams
else:
from jax.experimental.pallas.triton import ( # isort: skip
TritonCompilerParams,
)
# pylint: enable=ungrouped-imports


def _paged_attention_kernel(
10 changes: 8 additions & 2 deletions axlearn/common/kv_cache/paged_kv_cache_gpu_kernel.py
@@ -8,9 +8,15 @@
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas.triton import TritonCompilerParams

from axlearn.common.utils import Tensor
from axlearn.common.utils import _JAX_MEMORY_SPACE_SUPPORT, Tensor

# pylint: disable=ungrouped-imports
if _JAX_MEMORY_SPACE_SUPPORT:
from jax.experimental.pallas.triton import CompilerParams as TritonCompilerParams
else:
from jax.experimental.pallas.triton import TritonCompilerParams
# pylint: enable=ungrouped-imports


def _scatter_pages_kernel(
12 changes: 6 additions & 6 deletions axlearn/common/optimizers.py
@@ -37,7 +37,6 @@
import typing_extensions
from absl import logging
from jax import numpy as jnp
from jax._src.sharding_impls import TransferToMemoryKind
from optax._src import numerics

from axlearn.common import flax_struct, schedule
@@ -53,6 +52,8 @@
TransformPartitionSpecFn,
)
from axlearn.common.utils import (
DEVICE_MEMORY,
HOST_MEMORY,
MemoryKind,
Nested,
NestedTensor,
@@ -62,6 +63,7 @@
expand_vdicts,
flatten_items,
register_per_param_settings,
transfer_to_memory_kind,
tree_paths,
vectorized_tree_map,
)
@@ -2072,8 +2074,8 @@ def offload_optimizer(
optimizer: ConfigOr[PartitionedGradientTransformation],
*,
pattern: Union[str, re.Pattern] = ".*",
offload_src: MemoryKind = "device",
offload_dst: MemoryKind = "pinned_host",
offload_src: MemoryKind = DEVICE_MEMORY,
offload_dst: MemoryKind = HOST_MEMORY,
) -> PartitionedGradientTransformation:
"""Offload the state of the wrapped optimizer that matches `pattern` to `offload_dst`.

@@ -2145,9 +2147,7 @@ def _move_fn(state: optax.OptState, dst: MemoryKind) -> optax.OptState:
# released, so we have less memory pressure at that point in time.
return jax.tree.map(
lambda path, tensor: (
jax.device_put(tensor, TransferToMemoryKind(dst))
if re.fullmatch(pattern, path)
else tensor
transfer_to_memory_kind(tensor, dst) if re.fullmatch(pattern, path) else tensor
),
tree_paths(state),
state,
16 changes: 12 additions & 4 deletions axlearn/common/optimizers_test.py
@@ -60,6 +60,7 @@
from axlearn.common.schedule import Schedule, adafactor_decay_rate, decay_bias_correction
from axlearn.common.test_utils import TestCase, assert_allclose
from axlearn.common.utils import (
_JAX_MEMORY_SPACE_SUPPORT,
NestedPartitionSpec,
PartitionSpec,
Tensor,
@@ -427,10 +428,17 @@ def compute_loss(x):
return loss, compute_loss(updated_params)

if offload:
self.assertIn(
"TransferToMemoryKind(memory_kind='pinned_host')",
str(jax.make_jaxpr(jit_fn)(params, state)),
)
jaxpr_str = str(jax.make_jaxpr(jit_fn)(params, state))
if _JAX_MEMORY_SPACE_SUPPORT:
self.assertIn(
"memory_kind=host",
jaxpr_str,
)
else:
self.assertIn(
"TransferToMemoryKind(memory_kind='pinned_host')",
jaxpr_str,
)
loss, new_loss = jit_fn(params, state)
self.assertLess(new_loss, loss)

22 changes: 21 additions & 1 deletion axlearn/common/utils.py
@@ -54,6 +54,7 @@
from jax.experimental import mesh_utils, multihost_utils
from jax.extend.core import Primitive
from jax.sharding import PartitionSpec
from packaging import version

from axlearn.common import serialization
from axlearn.common.config import (
@@ -66,6 +67,9 @@
register_validator,
)

# Whether the installed JAX supports `jax.memory.Space` memory kinds (JAX >= 0.7.0).
_JAX_MEMORY_SPACE_SUPPORT = version.parse(jax.__version__) >= version.parse("0.7.0")
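# Note: elsewhere in the codebase this flag also serves as a proxy for other JAX >= 0.7.0 API
# changes, e.g. the Pallas Triton `CompilerParams` rename.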

# New code should use Nested[XX] instead of NestedXX.
# Old definitions are provided for backwards compatibility.
_NestedT = TypeVar("_NestedT")
@@ -118,7 +122,23 @@ def __len__(self):
# "pinned_host" = Page locked memory on CPU, which can be address directly by accelerators by
# direct memory access (DMA). For TPU, "pinned_host" memory layout follows TPU device tile
# layout and usually cannot be zero-copy converted to a CPU-tensor.
MemoryKind = Literal["device", "pinned_host"]
if _JAX_MEMORY_SPACE_SUPPORT:
MemoryKind = jax.memory.Space
DEVICE_MEMORY = jax.memory.Space.Device
HOST_MEMORY = jax.memory.Space.Host

def transfer_to_memory_kind(tensor: Tensor, memory_kind: MemoryKind) -> Tensor:
return jax.device_put(tensor, memory_kind)

else:
from jax._src.sharding_impls import TransferToMemoryKind # pylint: disable=ungrouped-imports

MemoryKind = Literal["device", "pinned_host"]
DEVICE_MEMORY = "device"
HOST_MEMORY = "pinned_host"

def transfer_to_memory_kind(tensor: Tensor, memory_kind: MemoryKind) -> Tensor:
return jax.device_put(tensor, TransferToMemoryKind(memory_kind))
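# Hypothetical usage sketch: move a tensor to host memory and back.
#   host_copy = transfer_to_memory_kind(tensor, HOST_MEMORY)
#   device_copy = transfer_to_memory_kind(host_copy, DEVICE_MEMORY)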


@dataclasses.dataclass
1 change: 0 additions & 1 deletion axlearn/experiments/text/gpt/c4_trainer.py
@@ -112,5 +112,4 @@ def named_trainer_configs() -> dict[str, TrainerConfigFn]:
config_map = {}
config_map.update(fuji.trainer_configs(_train_input_source, _eval_input_sources))
config_map.update(gspmd.trainer_configs(_train_input_source, _eval_input_sources))
config_map.update(envy.trainer_configs(_train_input_source, _eval_input_sources))
return config_map