From 45345b10a54db49a834201c20bbbe9cd94cb5f08 Mon Sep 17 00:00:00 2001
From: Steboss
Date: Sun, 7 Sep 2025 22:52:44 +0100
Subject: [PATCH 1/6] Fix MemoryKind error

---
 axlearn/common/optimizers.py      | 12 ++++++------
 axlearn/common/optimizers_test.py | 16 ++++++++++++----
 axlearn/common/utils.py           | 22 +++++++++++++++++++++-
 3 files changed, 39 insertions(+), 11 deletions(-)

diff --git a/axlearn/common/optimizers.py b/axlearn/common/optimizers.py
index 54e4f0b16..c5c1b7be5 100644
--- a/axlearn/common/optimizers.py
+++ b/axlearn/common/optimizers.py
@@ -37,7 +37,6 @@
 import typing_extensions
 from absl import logging
 from jax import numpy as jnp
-from jax._src.sharding_impls import TransferToMemoryKind
 from optax._src import numerics
 
 from axlearn.common import flax_struct, schedule
@@ -53,6 +52,8 @@
     TransformPartitionSpecFn,
 )
 from axlearn.common.utils import (
+    DEVICE_MEMORY,
+    HOST_MEMORY,
     MemoryKind,
     Nested,
     NestedTensor,
@@ -62,6 +63,7 @@
     expand_vdicts,
     flatten_items,
     register_per_param_settings,
+    transfer_to_memory_kind,
     tree_paths,
     vectorized_tree_map,
 )
@@ -2072,8 +2074,8 @@ def offload_optimizer(
     optimizer: ConfigOr[PartitionedGradientTransformation],
     *,
     pattern: Union[str, re.Pattern] = ".*",
-    offload_src: MemoryKind = "device",
-    offload_dst: MemoryKind = "pinned_host",
+    offload_src: MemoryKind = DEVICE_MEMORY,
+    offload_dst: MemoryKind = HOST_MEMORY,
 ) -> PartitionedGradientTransformation:
     """Offload the state of the wrapped optimizer that matches `pattern` to `offload_dst`.
 
@@ -2145,9 +2147,7 @@ def _move_fn(state: optax.OptState, dst: MemoryKind) -> optax.OptState:
         # released, so we have less memory pressure at that point in time.
         return jax.tree.map(
             lambda path, tensor: (
-                jax.device_put(tensor, TransferToMemoryKind(dst))
-                if re.fullmatch(pattern, path)
-                else tensor
+                transfer_to_memory_kind(tensor, dst) if re.fullmatch(pattern, path) else tensor
             ),
             tree_paths(state),
             state,
diff --git a/axlearn/common/optimizers_test.py b/axlearn/common/optimizers_test.py
index 6fe408239..761ddee49 100644
--- a/axlearn/common/optimizers_test.py
+++ b/axlearn/common/optimizers_test.py
@@ -60,6 +60,7 @@
 from axlearn.common.schedule import Schedule, adafactor_decay_rate, decay_bias_correction
 from axlearn.common.test_utils import TestCase, assert_allclose
 from axlearn.common.utils import (
+    _JAX_MEMORY_SPACE_SUPPORT,
     NestedPartitionSpec,
     PartitionSpec,
     Tensor,
@@ -427,10 +428,17 @@ def compute_loss(x):
             return loss, compute_loss(updated_params)
 
         if offload:
-            self.assertIn(
-                "TransferToMemoryKind(memory_kind='pinned_host')",
-                str(jax.make_jaxpr(jit_fn)(params, state)),
-            )
+            jaxpr_str = str(jax.make_jaxpr(jit_fn)(params, state))
+            if _JAX_MEMORY_SPACE_SUPPORT:
+                self.assertIn(
+                    "memory_kind=host",
+                    jaxpr_str,
+                )
+            else:
+                self.assertIn(
+                    "TransferToMemoryKind(memory_kind='pinned_host')",
+                    jaxpr_str,
+                )
 
         loss, new_loss = jit_fn(params, state)
         self.assertLess(new_loss, loss)
diff --git a/axlearn/common/utils.py b/axlearn/common/utils.py
index 16a5fe3b4..776340a81 100644
--- a/axlearn/common/utils.py
+++ b/axlearn/common/utils.py
@@ -54,6 +54,7 @@
 from jax.experimental import mesh_utils, multihost_utils
 from jax.extend.core import Primitive
 from jax.sharding import PartitionSpec
+from packaging import version
 
 from axlearn.common import serialization
 from axlearn.common.config import (
@@ -66,6 +67,9 @@
     register_validator,
 )
 
+# Check the JAX version to determine MemoryKind compatibility.
+_JAX_MEMORY_SPACE_SUPPORT = version.parse(jax.__version__) >= version.parse("0.7.0")
version.parse("0.7.0") + # New code should use Nested[XX] instead of NestedXX. # Old definitions are provided for backwards compatibility. _NestedT = TypeVar("_NestedT") @@ -118,7 +122,23 @@ def __len__(self): # "pinned_host" = Page locked memory on CPU, which can be address directly by accelerators by # direct memory access (DMA). For TPU, "pinned_host" memory layout follows TPU device tile # layout and usually cannot be zero-copy converted to a CPU-tensor. -MemoryKind = Literal["device", "pinned_host"] +if _JAX_MEMORY_SPACE_SUPPORT: + MemoryKind = [jax.memory.Space.Device, jax.memory.Space.Host] + DEVICE_MEMORY = jax.memory.Space.Device + HOST_MEMORY = jax.memory.Space.Host + + def transfer_to_memory_kind(tensor: Tensor, memory_kind: MemoryKind) -> Tensor: + return jax.device_put(tensor, memory_kind) + +else: + from jax._src.sharding_impls import TransferToMemoryKind # pylint: disable=ungrouped-imports + + MemoryKind = Literal["device", "pinned_host"] + DEVICE_MEMORY = "device" + HOST_MEMORY = "pinned_host" + + def transfer_to_memory_kind(tensor: Tensor, memory_kind: MemoryKind) -> Tensor: + return jax.device_put(tensor, TransferToMemoryKind(memory_kind)) @dataclasses.dataclass From 6ece93f2b6f55f6b018784c922714b97909c418a Mon Sep 17 00:00:00 2001 From: Steboss Date: Tue, 9 Sep 2025 13:20:16 +0100 Subject: [PATCH 2/6] update triton to latest changes --- axlearn/common/flash_attention/gpu_attention.py | 10 ++++++++-- axlearn/common/flash_attention/gpu_decoding.py | 10 ++++++++-- .../common/flash_attention/gpu_paged_attention.py | 12 ++++++++++-- axlearn/common/kv_cache/paged_kv_cache_gpu_kernel.py | 10 ++++++++-- 4 files changed, 34 insertions(+), 8 deletions(-) diff --git a/axlearn/common/flash_attention/gpu_attention.py b/axlearn/common/flash_attention/gpu_attention.py index 9026f3a3e..1d47a2669 100644 --- a/axlearn/common/flash_attention/gpu_attention.py +++ b/axlearn/common/flash_attention/gpu_attention.py @@ -44,7 +44,6 @@ ) from jax.ad_checkpoint import checkpoint_name from jax.experimental import pallas as pl -from jax.experimental.pallas.triton import TritonCompilerParams from axlearn.common.attention_bias import ( NEG_INF, @@ -69,7 +68,14 @@ from axlearn.common.kv_cache.base_kv_cache import BaseKVCache from axlearn.common.kv_cache.kv_cache import KVCache from axlearn.common.layers import get_dropout_mask -from axlearn.common.utils import Nested, Tensor +from axlearn.common.utils import _JAX_MEMORY_SPACE_SUPPORT, Nested, Tensor + +# pylint: disable=ungrouped-imports +if _JAX_MEMORY_SPACE_SUPPORT: + from jax.experimental.pallas.triton import CompilerParams as TritonCompilerParams +else: + from jax.experimental.pallas.triton import TritonCompilerParams +# pylint: disable=ungrouped-imports def _segment_mask( diff --git a/axlearn/common/flash_attention/gpu_decoding.py b/axlearn/common/flash_attention/gpu_decoding.py index a29bdcc5c..1b5b07bb4 100644 --- a/axlearn/common/flash_attention/gpu_decoding.py +++ b/axlearn/common/flash_attention/gpu_decoding.py @@ -49,7 +49,6 @@ from jax import lax from jax._src.cudnn.fused_attention_stablehlo import check_compute_capability from jax.experimental import pallas as pl -from jax.experimental.pallas.triton import TritonCompilerParams from axlearn.common.attention_bias import ( NEG_INF, @@ -61,7 +60,14 @@ from axlearn.common.flash_attention.common import BaseSingleStepDecoding, get_gpu_dot_precision from axlearn.common.kv_cache.base_kv_cache import BaseKVCache from axlearn.common.kv_cache.kv_cache import KVCache -from axlearn.common.utils import 
+from axlearn.common.utils import _JAX_MEMORY_SPACE_SUPPORT, Nested, Tensor
+
+# pylint: disable=ungrouped-imports
+if _JAX_MEMORY_SPACE_SUPPORT:
+    from jax.experimental.pallas.triton import CompilerParams as TritonCompilerParams
+else:
+    from jax.experimental.pallas.triton import TritonCompilerParams
+# pylint: enable=ungrouped-imports
 
 
 # Note: split_k_seq_len must be a multiple of block_k.
diff --git a/axlearn/common/flash_attention/gpu_paged_attention.py b/axlearn/common/flash_attention/gpu_paged_attention.py
index a2600de4a..9a3dd9d0e 100644
--- a/axlearn/common/flash_attention/gpu_paged_attention.py
+++ b/axlearn/common/flash_attention/gpu_paged_attention.py
@@ -19,7 +19,6 @@
 import jax.numpy as jnp
 from jax import lax
 from jax.experimental import pallas as pl
-from jax.experimental.pallas.triton import TritonCompilerParams
 
 from axlearn.common.attention_bias import (
     NEG_INF,
@@ -31,7 +30,16 @@
 from axlearn.common.flash_attention.common import BasePagedAttention, get_gpu_dot_precision
 from axlearn.common.flash_attention.gpu_decoding import _get_sm_count as get_sm_count
 from axlearn.common.kv_cache.base_kv_cache import BaseKVCache
-from axlearn.common.utils import Nested, Tensor
+from axlearn.common.utils import _JAX_MEMORY_SPACE_SUPPORT, Nested, Tensor
+
+# pylint: disable=ungrouped-imports
+if _JAX_MEMORY_SPACE_SUPPORT:
+    from jax.experimental.pallas.triton import CompilerParams as TritonCompilerParams
+else:
+    from jax.experimental.pallas.triton import (  # isort: skip
+        TritonCompilerParams,
+    )
+# pylint: enable=ungrouped-imports
 
 
 def _paged_attention_kernel(
diff --git a/axlearn/common/kv_cache/paged_kv_cache_gpu_kernel.py b/axlearn/common/kv_cache/paged_kv_cache_gpu_kernel.py
index cbddb276d..5358d908c 100644
--- a/axlearn/common/kv_cache/paged_kv_cache_gpu_kernel.py
+++ b/axlearn/common/kv_cache/paged_kv_cache_gpu_kernel.py
@@ -8,9 +8,15 @@
 import jax
 import jax.numpy as jnp
 from jax.experimental import pallas as pl
-from jax.experimental.pallas.triton import TritonCompilerParams
 
-from axlearn.common.utils import Tensor
+from axlearn.common.utils import _JAX_MEMORY_SPACE_SUPPORT, Tensor
+
+# pylint: disable=ungrouped-imports
+if _JAX_MEMORY_SPACE_SUPPORT:
+    from jax.experimental.pallas.triton import CompilerParams as TritonCompilerParams
+else:
+    from jax.experimental.pallas.triton import TritonCompilerParams
+# pylint: enable=ungrouped-imports
 
 
 def _scatter_pages_kernel(

From 5641f6cec9e444c773bfebc4aec039e821252628 Mon Sep 17 00:00:00 2001
From: Steboss
Date: Fri, 10 Oct 2025 14:02:50 +0200
Subject: [PATCH 3/6] temporary fix for envy

---
 axlearn/experiments/text/gpt/c4_trainer.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/axlearn/experiments/text/gpt/c4_trainer.py b/axlearn/experiments/text/gpt/c4_trainer.py
index b9103cb7a..b8e8b22b2 100644
--- a/axlearn/experiments/text/gpt/c4_trainer.py
+++ b/axlearn/experiments/text/gpt/c4_trainer.py
@@ -49,7 +49,7 @@
 from axlearn.common.input_lm import lm_text_preprocessor
 from axlearn.common.utils import get_data_dir
 from axlearn.experiments.text.common import DataMixtureComponent, vocab
-from axlearn.experiments.text.gpt import envy, fuji, gspmd
+from axlearn.experiments.text.gpt import fuji, gspmd
 from axlearn.experiments.text.gpt.common import mixture_train_input_source, tfds_input
 from axlearn.experiments.text.gpt.vocabulary_fuji_v3 import FujiV3Vocabulary
 
@@ -109,5 +109,4 @@ def named_trainer_configs() -> dict[str, TrainerConfigFn]:
     config_map = {}
     config_map.update(fuji.trainer_configs(_train_input_source, _eval_input_sources))
     config_map.update(gspmd.trainer_configs(_train_input_source, _eval_input_sources))
-    config_map.update(envy.trainer_configs(_train_input_source, _eval_input_sources))
     return config_map

From 90cab09c231b367cf8dfd7b60494c3b198105ca2 Mon Sep 17 00:00:00 2001
From: Steboss
Date: Fri, 10 Oct 2025 14:23:26 +0200
Subject: [PATCH 4/6] Add a check on the JAX version to solve the version error

---
 axlearn/common/array_serialization.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/axlearn/common/array_serialization.py b/axlearn/common/array_serialization.py
index 369926e2e..5bd271e83 100644
--- a/axlearn/common/array_serialization.py
+++ b/axlearn/common/array_serialization.py
@@ -306,10 +306,12 @@ async def _async_serialize(
         and arr_inp.is_fully_addressable
     )
     # pylint: disable=protected-access
-    spec_has_metadata = {
-        "0.6.2": lambda: serialization.ts_impl._spec_has_metadata,
-        "0.5.3": lambda: serialization._spec_has_metadata,
-    }[jax.__version__]()
+    if jax.__version__.startswith("0.8.0") or jax.__version__ == "0.6.2":
+        spec_has_metadata = serialization.ts_impl._spec_has_metadata
+    elif jax.__version__ == "0.5.3":
+        spec_has_metadata = serialization._spec_has_metadata
+    else:
+        raise ValueError(f"Unsupported JAX version for spec_has_metadata: {jax.__version__}")
     if not spec_has_metadata(tensorstore_spec):
         # pylint: disable-next=protected-access
         tensorstore_spec["metadata"] = serialization._get_metadata(arr_inp)

From d7f4bcb2575a8fc01f66809d7ed81680d35f856a Mon Sep 17 00:00:00 2001
From: Steboss
Date: Mon, 20 Oct 2025 18:55:02 +0200
Subject: [PATCH 5/6] Match any JAX 0.8.x version in the check

---
 axlearn/common/array_serialization.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/axlearn/common/array_serialization.py b/axlearn/common/array_serialization.py
index 5bd271e83..729b122b6 100644
--- a/axlearn/common/array_serialization.py
+++ b/axlearn/common/array_serialization.py
@@ -306,7 +306,7 @@ async def _async_serialize(
         and arr_inp.is_fully_addressable
     )
     # pylint: disable=protected-access
-    if jax.__version__.startswith("0.8.0") or jax.__version__ == "0.6.2":
+    if jax.__version__.startswith("0.8.") or jax.__version__ == "0.6.2":
         spec_has_metadata = serialization.ts_impl._spec_has_metadata
     elif jax.__version__ == "0.5.3":
         spec_has_metadata = serialization._spec_has_metadata

From c211883d5325d2199ef077a91a44e5d34bfdf4a2 Mon Sep 17 00:00:00 2001
From: Steboss
Date: Tue, 21 Oct 2025 12:06:37 +0200
Subject: [PATCH 6/6] Use another strategy to get the JAX version

---
 axlearn/common/array_serialization.py | 42 ++++++++++++++++++---------
 1 file changed, 28 insertions(+), 14 deletions(-)

diff --git a/axlearn/common/array_serialization.py b/axlearn/common/array_serialization.py
index 729b122b6..536af6c31 100644
--- a/axlearn/common/array_serialization.py
+++ b/axlearn/common/array_serialization.py
@@ -36,6 +36,7 @@
 from jax._src import array, typing
 from jax._src.layout import Layout
 from jax.experimental.array_serialization import serialization
+from packaging import version
 
 from axlearn.common.utils import Tensor
 
@@ -75,6 +76,7 @@ def shard_coordinate(self):
 
 # Tuple (and thus hashable) representation of a slice object (start, end, step).
 _SliceTuple = tuple[Optional[int], Optional[int], Optional[int]]
+JAX_VERSION = version.parse(jax.__version__)
 
 
 def _slices_to_tuple(slices: list[slice]) -> tuple[_SliceTuple, ...]:
@@ -306,12 +308,13 @@ async def _async_serialize(
         and arr_inp.is_fully_addressable
     )
     # pylint: disable=protected-access
-    if jax.__version__.startswith("0.8.") or jax.__version__ == "0.6.2":
+    if JAX_VERSION >= version.parse("0.6.2"):
         spec_has_metadata = serialization.ts_impl._spec_has_metadata
-    elif jax.__version__ == "0.5.3":
+    elif JAX_VERSION >= version.parse("0.5.3"):
         spec_has_metadata = serialization._spec_has_metadata
     else:
         raise ValueError(f"Unsupported JAX version for spec_has_metadata: {jax.__version__}")
+
     if not spec_has_metadata(tensorstore_spec):
         # pylint: disable-next=protected-access
         tensorstore_spec["metadata"] = serialization._get_metadata(arr_inp)
@@ -488,10 +491,14 @@ async def _async_deserialize(
     async def cb(index: array.Index, device: jax.Device):
         requested_domain = ts.IndexTransform(input_shape=shape)[index].domain
         restricted_domain = t.domain.intersect(requested_domain)
-        estimate_read_memory_footprint = {
-            "0.6.2": lambda: serialization.ts_impl.estimate_read_memory_footprint,
-            "0.5.3": lambda: serialization.estimate_read_memory_footprint,
-        }[jax.__version__]()
+        if JAX_VERSION >= version.parse("0.6.2"):
+            estimate_read_memory_footprint = serialization.ts_impl.estimate_read_memory_footprint
+        elif JAX_VERSION >= version.parse("0.5.3"):
+            estimate_read_memory_footprint = serialization.estimate_read_memory_footprint
+        else:
+            raise ValueError(
+                f"Unsupported JAX version: {JAX_VERSION}. Version must be 0.5.3 or newer."
+            )
         requested_bytes = estimate_read_memory_footprint(t, restricted_domain)
         # Limit the bytes read for every shard.
         await byte_limiter.wait_for_bytes(requested_bytes)
@@ -569,10 +576,13 @@ async def cb(index: array.Index, device: jax.Device):
         return result
 
     # pylint: disable=protected-access
-    create_async_array_from_callback = {
-        "0.6.2": lambda: serialization.ts_impl._create_async_array_from_callback,
-        "0.5.3": lambda: serialization.create_async_array_from_callback,
-    }[jax.__version__]()
+    if JAX_VERSION >= version.parse("0.6.2"):
+        create_async_array_from_callback = serialization.ts_impl._create_async_array_from_callback
+    elif JAX_VERSION >= version.parse("0.5.3"):
+        create_async_array_from_callback = serialization.create_async_array_from_callback
+    else:
+        raise ValueError(f"Unsupported JAX version: {JAX_VERSION}. Version must be 0.5.3 or newer.")
+
     return await create_async_array_from_callback(shape, in_sharding, cb)
 
 
@@ -654,10 +664,14 @@ def serialize(
 
         commit_futures = [[] for _ in range(len(tensorstore_specs))]
 
-        async_serialize = {
-            "0.6.2": lambda: serialization.ts_impl.async_serialize,
-            "0.5.3": lambda: serialization.async_serialize,
-        }[jax.__version__]()
+        if JAX_VERSION >= version.parse("0.6.2"):
+            async_serialize = serialization.ts_impl.async_serialize
+        elif JAX_VERSION >= version.parse("0.5.3"):
+            async_serialize = serialization.async_serialize
+        else:
+            raise ValueError(
+                f"Unsupported JAX version: {JAX_VERSION}. Version must be 0.5.3 or newer."
+            )
 
         # pylint: disable-next=redefined-outer-name
         async def _run_serializer():