From 331bcd20092fec32ec8077c2febab8272b58f0ba Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Thu, 11 Sep 2025 09:26:36 -0700 Subject: [PATCH 01/50] add base benchmark scripts --- benchmark_deserialize.py | 256 +++++++++++++++++++++++++++++++++++++++ proxy_bench.py | 88 ++++++++++++++ 2 files changed, 344 insertions(+) create mode 100644 benchmark_deserialize.py create mode 100644 proxy_bench.py diff --git a/benchmark_deserialize.py b/benchmark_deserialize.py new file mode 100644 index 000000000..c625e795b --- /dev/null +++ b/benchmark_deserialize.py @@ -0,0 +1,256 @@ +""" +A script to benchmark the GlobalAsyncCheckpointManager.deserialize function. + +This script contains a local patch for the deserialization logic to work around +a bug in the installed axlearn library, avoiding any modification to the library itself. +""" + +import asyncio +import functools +import math +import os +import time +from concurrent.futures import ThreadPoolExecutor +from typing import Any, Optional, Sequence, Union + +import jax +import jax.numpy as jnp +import numpy as np +import tensorstore as ts +from absl import app, flags +from jax._src import array, typing +from jax._src.layout import Layout +from jax.experimental.array_serialization import serialization +from jax.experimental.array_serialization.serialization import get_tensorstore_spec +from jax.sharding import Mesh, NamedSharding, PartitionSpec + +from axlearn.common.array_serialization import ( + GlobalAsyncCheckpointManager, + _get_premapped_buffer_size, +) +from axlearn.common.checkpointer import read_state_spec +from axlearn.common.utils import flatten_items + +# JAX platforms might be initialized by another process. +# We follow the logic in axlearn.common.launch to initialize JAX. +if os.environ.get("JAX_PLATFORMS", "") == "proxy": + import pathwaysutils # type: ignore + + pathwaysutils.initialize() +else: + jax.distributed.initialize() + + +FLAGS = flags.FLAGS + +flags.DEFINE_string( + "checkpoint_dir", + "gs://cloud-tpu-multipod-dev-axlearn/stoelinga-v7-70b-17/checkpoints/step_00000100/", + "The GCS path to the checkpoint step directory.", +) +flags.DEFINE_integer("num_iterations", 5, "The number of benchmark iterations.") +flags.DEFINE_integer("warmup_iterations", 1, "The number of warmup iterations.") +flags.DEFINE_integer("batch_size", 64, "The number of tensors to deserialize in each batch.") + + +# --- Local Patch for Deserialization --- +# The following functions are copied from axlearn.common.array_serialization +# and patched locally to fix a TypeError without modifying the library. + + +def _blocking_device_put(out: np.ndarray, layout: Layout) -> jax.Array: + return jax.block_until_ready(jax.device_put(out, layout)) + + +async def _patched_async_deserialize( + user_in_sharding: jax.sharding.Sharding | Layout, + tensorstore_spec: dict[str, Any], + global_shape: Optional[Sequence[int]], + dtype: Optional[typing.DTypeLike], + *, + h2d_limiter: serialization._LimitInFlightBytes, + byte_limiter: serialization._LimitInFlightBytes, + single_thread_pool: ThreadPoolExecutor, +): + """Patched version of _async_deserialize.""" + in_sharding = ( + user_in_sharding.sharding if isinstance(user_in_sharding, Layout) else user_in_sharding + ) + if not isinstance(in_sharding, jax.sharding.Sharding): + raise ValueError( + "sharding passed to deserialization should be specified, concrete and" + f" an instance of `jax.sharding.Sharding`. 
Got {in_sharding}" + ) + dll = user_in_sharding.device_local_layout if isinstance(user_in_sharding, Layout) else None + t = await ts.open( + tensorstore_spec, + open=True, + assume_metadata=False, + context=serialization.TS_CONTEXT, + ) + shape = tuple(t.shape if global_shape is None else global_shape) + new_shard_shape = in_sharding.shard_shape(shape) + loop = asyncio.get_running_loop() + + async def cb(index: array.Index, device: jax.Device): + requested_domain = ts.IndexTransform(input_shape=shape)[index].domain + restricted_domain = t.domain.intersect(requested_domain) + requested_bytes = serialization.ts_impl.estimate_read_memory_footprint(t, restricted_domain) + await byte_limiter.wait_for_bytes(requested_bytes) + read_ts = t[restricted_domain] + if dtype is not None: + read_ts = ts.cast(read_ts, dtype) + if tuple(t.shape) == shape: + out = np.empty(new_shard_shape, read_ts.dtype.numpy_dtype) + else: + out = np.zeros(new_shard_shape, read_ts.dtype.numpy_dtype) + + await ts.array(out)[ts.d[:].translate_to[requested_domain.origin]][restricted_domain].write( + read_ts + ) + + if out.dtype == jnp.int4: + out = jnp.asarray(out) + + out_size = out.size * out.dtype.itemsize + mb_256 = 256 * 1024 * 1024 + out_size = math.ceil(out_size / mb_256) * mb_256 + + sharding_for_put = jax.sharding.SingleDeviceSharding( + device, memory_kind=in_sharding.memory_kind + ) + if dll is not None: + sharding_for_put = Layout(dll, sharding_for_put) + + try: + await h2d_limiter.wait_for_bytes(out_size) + result = await loop.run_in_executor(None, _blocking_device_put, out, sharding_for_put) + await h2d_limiter.release_bytes(out_size) + except ValueError as e: + if "Requested more bytes than we reserved" not in str(e): + raise e + result = await loop.run_in_executor( + single_thread_pool, _blocking_device_put, out, sharding_for_put + ) + + await byte_limiter.release_bytes(requested_bytes) + return result + + # This is the patched line. + # pylint: disable-next=protected-access + return await serialization.ts_impl._create_async_array_from_callback( + shape, dtype, in_sharding, cb + ) + + +class PatchedGlobalAsyncCheckpointManager(GlobalAsyncCheckpointManager): + """An override of the manager to use our patched deserialize logic.""" + + def deserialize( + self, + shardings: Sequence[Union[jax.sharding.Sharding, Layout]], + tensorstore_specs: Sequence[dict[str, Any]], + global_shapes: Optional[Sequence[array.Shape]] = None, + dtypes: Optional[Sequence[typing.DTypeLike]] = None, + concurrent_gb: int = 32, + ): + self.wait_until_finished() + concurrent_bytes = concurrent_gb * 10**9 + + async def _run_deserializer(): + # pylint: disable=protected-access + byte_limiter = serialization._LimitInFlightBytes(concurrent_bytes) + h2d_limiter = serialization._LimitInFlightBytes(_get_premapped_buffer_size()) + future_arrays = jax.tree.map( + functools.partial( + _patched_async_deserialize, # Use our patched function. 
+ byte_limiter=byte_limiter, + h2d_limiter=h2d_limiter, + single_thread_pool=self._single_thread_pool, + ), + shardings, + tensorstore_specs, + [None] * len(tensorstore_specs) if global_shapes is None else global_shapes, + [None] * len(tensorstore_specs) if dtypes is None else dtypes, + ) + return await asyncio.gather(*future_arrays) + + fut = asyncio.run_coroutine_threadsafe(_run_deserializer(), self._loop) + return fut.result() + + +def main(argv: Sequence[str]) -> None: + """Benchmarks the deserialize function.""" + del argv + + devices = jax.devices() + mesh = Mesh(devices, axis_names=("data",)) + + state_spec = read_state_spec(FLAGS.checkpoint_dir) + flat_state_spec = flatten_items(state_spec, separator="/") + + ts_specs, shardings_list, global_shapes, dtypes = [], [], [], [] + + for path, spec in flat_state_spec: + gda_path = os.path.join(FLAGS.checkpoint_dir, "gda", path) + ts_specs.append(get_tensorstore_spec(gda_path)) + + partition_spec = PartitionSpec() + if len(spec.shape) > 0 and spec.shape[0] % len(devices) == 0: + partition_spec = PartitionSpec("data", *(None,) * (len(spec.shape) - 1)) + + shardings_list.append(NamedSharding(mesh, partition_spec)) + global_shapes.append(spec.shape) + dtypes.append(spec.dtype) + + manager = PatchedGlobalAsyncCheckpointManager() + + def run_deserialize(): + """Runs deserialization across all tensors, processing them in batches.""" + total_duration = 0 + num_tensors = len(ts_specs) + for i in range(0, num_tensors, FLAGS.batch_size): + batch_start_time = time.time() + batch_end = min(i + FLAGS.batch_size, num_tensors) + print( + f" Deserializing batch {i // FLAGS.batch_size + 1}/" + f"{math.ceil(num_tensors / FLAGS.batch_size)} (tensors {i}-{batch_end-1})..." + ) + + restored_arrays = manager.deserialize( + shardings=shardings_list[i:batch_end], + tensorstore_specs=ts_specs[i:batch_end], + global_shapes=global_shapes[i:batch_end], + dtypes=dtypes[i:batch_end], + ) + for arr in restored_arrays: + arr.block_until_ready() + + batch_duration = time.time() - batch_start_time + total_duration += batch_duration + return total_duration + + print(f"Running {FLAGS.warmup_iterations} warmup iterations...") + for _ in range(FLAGS.warmup_iterations): + run_deserialize() + + print(f"Running {FLAGS.num_iterations} benchmark iterations...") + durations = [] + for i in range(FLAGS.num_iterations): + duration = run_deserialize() + print(f"Iteration {i+1} took {duration:.4f} seconds.") + durations.append(duration) + + print("\n--- Benchmark Results ---") + print(f"Number of devices: {len(devices)}") + print(f"Iterations: {FLAGS.num_iterations}") + print(f"Average time: {sum(durations) / len(durations):.4f} seconds") + print(f"Min time: {min(durations):.4f} seconds") + print(f"Max time: {max(durations):.4f} seconds") + print("-------------------------\n") + + manager.stop() + + +if __name__ == "__main__": + app.run(main) diff --git a/proxy_bench.py b/proxy_bench.py new file mode 100644 index 000000000..ca5312d34 --- /dev/null +++ b/proxy_bench.py @@ -0,0 +1,88 @@ +"""A script to benchmark JAX device_put throughput.""" + +import os +import time + +import jax +import numpy as np +import pathwaysutils +from jax.sharding import Mesh, NamedSharding, PartitionSpec + + +def benchmark_host_to_device_throughput(device_put_buffer_mb: int = 1024): + """Benchmarks JAX device_put throughput from CPU host to a v5e-32 TPU slice.""" + print(f"JAX version: {jax.__version__}") + devices = jax.devices() if os.environ.get("JAX_PLATFORMS") else jax.local_devices() + num_devices = 
len(devices) + print(f"Available devices: {num_devices}") + + data_bytes_per_device = int(device_put_buffer_mb * 1024 * 1024) # 1 GiB + dtype = np.float32 + num_elements = data_bytes_per_device // np.dtype(dtype).itemsize + data_gb_per_device = num_elements * np.dtype(dtype).itemsize / (1024**3) + + print( + f"Creating a NumPy array of shape ({num_elements},) type {dtype}, size" + f" {data_gb_per_device:.2f} GiB" + ) + host_array = np.arange(num_elements, dtype=dtype) + + # Create a mesh spanning all devices. + mesh = Mesh(np.array(devices), axis_names=("i",)) + # An empty PartitionSpec() means the array is fully replicated across all + # devices in the mesh. + replicated_sharding = NamedSharding(mesh, PartitionSpec()) + print(f"Using sharding for replication: {replicated_sharding}") + + # Warm-up transfer + print("Performing warm-up transfer...") + try: + dummy_array = jax.device_put(host_array, replicated_sharding) + dummy_array.block_until_ready() + print("Warm-up complete.") + except RuntimeError as e: + print(f"Error during warm-up: {e}") + return + + # Benchmark loop + num_transfers = 5 + transfer_times = [] + + print(f"Starting benchmark ({num_transfers} transfers)...") + for i in range(num_transfers): + if i == 0: + trace_dir = "gs://cloud-tpu-multipod-dev-axlearn/stoelinga-proxy-benchmark" + jax.profiler.start_trace(f"{trace_dir}/{device_put_buffer_mb}mb") + start_time = time.perf_counter() + device_array = jax.device_put(host_array, replicated_sharding) + device_array.block_until_ready() + end_time = time.perf_counter() + + duration = end_time - start_time + transfer_times.append(duration) + print(f"Transfer {i+1}/{num_transfers}: {duration:.4f} seconds") + if i == 0: + jax.profiler.stop_trace() + del device_array # Optional: hint for early deletion + + avg_time = np.mean(transfer_times) + print(f"\nAverage time per device_put call: {avg_time:.4f} seconds") + + total_data_moved_gb = data_gb_per_device * num_devices + throughput_gb_s = total_data_moved_gb / avg_time + + print(f"Data per device: {data_gb_per_device:.2f} GiB") + print("Total data transferred from host per operation:" f" {total_data_moved_gb:.2f} GiB") + print(f"Aggregated Host -> Devices Throughput: {throughput_gb_s:.2f} GiB/s") + print(f"Aggregated Host -> Devices Throughput: {throughput_gb_s * 8:.2f} Gbps/s") + + +if __name__ == "__main__": + if os.environ.get("JAX_PLATFORMS") == "proxy": + pathwaysutils.initialize() + else: + jax.distributed.initialize() + scenarios_mb = [1, 128, 1024, 2048] + for scenario in scenarios_mb: + print(f"Running scenario {scenario}MB") + benchmark_host_to_device_throughput(scenario) From 2c01b114cfaa13b1353bc88ac0fc15c64c745dac Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Thu, 11 Sep 2025 09:28:50 -0700 Subject: [PATCH 02/50] remove pathways cpu and memory limit --- axlearn/cloud/gcp/pathways_utils.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/axlearn/cloud/gcp/pathways_utils.py b/axlearn/cloud/gcp/pathways_utils.py index fa47ef64b..986d5ecaf 100644 --- a/axlearn/cloud/gcp/pathways_utils.py +++ b/axlearn/cloud/gcp/pathways_utils.py @@ -107,7 +107,7 @@ def get_pathways_tpu_version(gke_machine_type: str) -> str: def get_megascale_options( - xla_options: dict[str, Union[str, bool, int]] + xla_options: dict[str, Union[str, bool, int]], ) -> dict[str, Union[str, bool, int]]: """Filters XLA options for those pertaining to Megascale. 
@@ -122,7 +122,7 @@ def get_megascale_options( def get_xla_options( - xla_options: dict[str, Union[str, bool, int]] + xla_options: dict[str, Union[str, bool, int]], ) -> dict[str, Union[str, bool, int]]: """Filters XLA options for those starting with 'xla_'. @@ -315,7 +315,7 @@ def _build_pathways_head_container(self) -> dict: mem_req = f"{self.config.pathways_head_mem}Gi" resources = { "requests": {"cpu": cpu_req, "memory": mem_req}, - "limits": {"cpu": cpu_req, "memory": mem_req}, + # "limits": {"cpu": cpu_req, "memory": mem_req}, } head_container["resources"] = resources @@ -910,7 +910,7 @@ def _build_head_container(self) -> dict: mem_req = f"{self.config.pathways_head_mem}Gi" resources = { "requests": {"cpu": cpu_req, "memory": mem_req}, - "limits": {"cpu": cpu_req, "memory": mem_req}, + # "limits": {"cpu": cpu_req, "memory": mem_req}, } return dict( name=cfg.name, @@ -936,9 +936,9 @@ def _build_head_container(self) -> dict: ], imagePullPolicy="Always", resources=resources, - ports=[dict(containerPort=self.config.target_port)] - if self.config.enable_service - else [], + ports=( + [dict(containerPort=self.config.target_port)] if self.config.enable_service else [] + ), ) def build_leader_pod(self) -> Nested[Any]: From 2876881ce1b92cdfce8eb12b274f6350a8f3bf7a Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Thu, 11 Sep 2025 09:37:23 -0700 Subject: [PATCH 03/50] remove batching --- benchmark_deserialize.py | 35 +++++++++++------------------------ 1 file changed, 11 insertions(+), 24 deletions(-) diff --git a/benchmark_deserialize.py b/benchmark_deserialize.py index c625e795b..0a0f20d92 100644 --- a/benchmark_deserialize.py +++ b/benchmark_deserialize.py @@ -50,7 +50,6 @@ ) flags.DEFINE_integer("num_iterations", 5, "The number of benchmark iterations.") flags.DEFINE_integer("warmup_iterations", 1, "The number of warmup iterations.") -flags.DEFINE_integer("batch_size", 64, "The number of tensors to deserialize in each batch.") # --- Local Patch for Deserialization --- @@ -206,29 +205,17 @@ def main(argv: Sequence[str]) -> None: manager = PatchedGlobalAsyncCheckpointManager() def run_deserialize(): - """Runs deserialization across all tensors, processing them in batches.""" - total_duration = 0 - num_tensors = len(ts_specs) - for i in range(0, num_tensors, FLAGS.batch_size): - batch_start_time = time.time() - batch_end = min(i + FLAGS.batch_size, num_tensors) - print( - f" Deserializing batch {i // FLAGS.batch_size + 1}/" - f"{math.ceil(num_tensors / FLAGS.batch_size)} (tensors {i}-{batch_end-1})..." 
- ) - - restored_arrays = manager.deserialize( - shardings=shardings_list[i:batch_end], - tensorstore_specs=ts_specs[i:batch_end], - global_shapes=global_shapes[i:batch_end], - dtypes=dtypes[i:batch_end], - ) - for arr in restored_arrays: - arr.block_until_ready() - - batch_duration = time.time() - batch_start_time - total_duration += batch_duration - return total_duration + """Runs deserialization across all tensors.""" + start_time = time.time() + restored_arrays = manager.deserialize( + shardings=shardings_list, + tensorstore_specs=ts_specs, + global_shapes=global_shapes, + dtypes=dtypes, + ) + for arr in restored_arrays: + arr.block_until_ready() + return time.time() - start_time print(f"Running {FLAGS.warmup_iterations} warmup iterations...") for _ in range(FLAGS.warmup_iterations): From 4a9f1f835edfabcdeadd0733354eb37ce3ea6bbd Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Thu, 11 Sep 2025 09:43:07 -0700 Subject: [PATCH 04/50] fix for jax 0.5.3 --- benchmark_deserialize.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/benchmark_deserialize.py b/benchmark_deserialize.py index 0a0f20d92..713954501 100644 --- a/benchmark_deserialize.py +++ b/benchmark_deserialize.py @@ -94,7 +94,7 @@ async def _patched_async_deserialize( async def cb(index: array.Index, device: jax.Device): requested_domain = ts.IndexTransform(input_shape=shape)[index].domain restricted_domain = t.domain.intersect(requested_domain) - requested_bytes = serialization.ts_impl.estimate_read_memory_footprint(t, restricted_domain) + requested_bytes = serialization.estimate_read_memory_footprint(t, restricted_domain) await byte_limiter.wait_for_bytes(requested_bytes) read_ts = t[restricted_domain] if dtype is not None: @@ -137,9 +137,7 @@ async def cb(index: array.Index, device: jax.Device): # This is the patched line. 
# pylint: disable-next=protected-access - return await serialization.ts_impl._create_async_array_from_callback( - shape, dtype, in_sharding, cb - ) + return await serialization.create_async_array_from_callback(shape, in_sharding, cb) class PatchedGlobalAsyncCheckpointManager(GlobalAsyncCheckpointManager): From 1262bf7a8bb9abe1d69323b352d00d88742f322a Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Thu, 11 Sep 2025 09:44:37 -0700 Subject: [PATCH 05/50] limit amount of max bytes being restored --- benchmark_deserialize.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/benchmark_deserialize.py b/benchmark_deserialize.py index 713954501..490a277e9 100644 --- a/benchmark_deserialize.py +++ b/benchmark_deserialize.py @@ -154,6 +154,19 @@ def deserialize( self.wait_until_finished() concurrent_bytes = concurrent_gb * 10**9 + max_shard_bytes = 0 + if global_shapes and dtypes: + for sharding, shape, dtype in zip(shardings, global_shapes, dtypes): + if isinstance(sharding, Layout): + sharding = sharding.sharding + shard_shape = sharding.shard_shape(shape) + shard_bytes = np.prod(shard_shape) * np.dtype(dtype).itemsize + if shard_bytes > max_shard_bytes: + max_shard_bytes = shard_bytes + + if max_shard_bytes > concurrent_bytes: + concurrent_bytes = int(max_shard_bytes) + 1 + async def _run_deserializer(): # pylint: disable=protected-access byte_limiter = serialization._LimitInFlightBytes(concurrent_bytes) From c057ecc5b0793eed031a8d388c1d5e840d2be907 Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Thu, 11 Sep 2025 17:47:19 -0700 Subject: [PATCH 06/50] WIP: send all jax device puts in parallel with pathways --- axlearn/common/array_serialization.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/axlearn/common/array_serialization.py b/axlearn/common/array_serialization.py index 9ba0bbf81..3e993ab71 100644 --- a/axlearn/common/array_serialization.py +++ b/axlearn/common/array_serialization.py @@ -205,8 +205,7 @@ def _fix_metadata(tspec: dict[str, Any], shard_infos: list[_ShardInfo]): class TensorstoreSpecModifier: - def __call__(self, spec: dict[str, Any], *, shard_infos: list[_ShardInfo]): - ... + def __call__(self, spec: dict[str, Any], *, shard_infos: list[_ShardInfo]): ... async def _async_serialize( @@ -454,9 +453,12 @@ async def cb(index: array.Index, device: jax.Device): dll, jax.sharding.SingleDeviceSharding(device, memory_kind=in_sharding.memory_kind) ) try: - await h2d_limiter.wait_for_bytes(out_size) - result = await loop.run_in_executor(None, _blocking_device_put, out, layout) - await h2d_limiter.release_bytes(out_size) + if os.getenv("JAX_PLATFORMS") == "proxy": + result = await loop.run_in_executor(None, _blocking_device_put, out, layout) + else: + await h2d_limiter.wait_for_bytes(out_size) + result = await loop.run_in_executor(None, _blocking_device_put, out, layout) + await h2d_limiter.release_bytes(out_size) except ValueError as e: if "Requested more bytes than we reserved" not in str(e): raise e # Raise if it's not the type of error we expect. 
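Note on the change in PATCH 06 above: it skips the premapped-buffer (h2d) limiter whenever JAX_PLATFORMS is "proxy", so that under Pathways every host-to-device transfer is issued at once instead of being throttled. The snippet below is only an illustrative sketch of that gating pattern, not code from the repo; `arrays` and `shardings` are placeholder names for a list of host buffers and their per-device shardings.

import asyncio
import os

import jax


async def put_all(arrays, shardings):
    """Issues one device_put per array; under Pathways proxy, all transfers run concurrently."""
    loop = asyncio.get_running_loop()
    if os.getenv("JAX_PLATFORMS") == "proxy":
        # No host-side byte limiter: fan every transfer out to the default executor.
        futs = [
            loop.run_in_executor(None, jax.device_put, arr, sharding)
            for arr, sharding in zip(arrays, shardings)
        ]
        return jax.block_until_ready(await asyncio.gather(*futs))
    # Non-proxy backends: transfer one array at a time.
    return [
        jax.block_until_ready(jax.device_put(arr, sharding))
        for arr, sharding in zip(arrays, shardings)
    ]
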
From ca6dd86960ba2e63956f143c27a820fbdd84ddd7 Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Thu, 11 Sep 2025 18:28:01 -0700 Subject: [PATCH 07/50] also prevent byte limiter --- axlearn/common/array_serialization.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/axlearn/common/array_serialization.py b/axlearn/common/array_serialization.py index 3e993ab71..75f3e7b12 100644 --- a/axlearn/common/array_serialization.py +++ b/axlearn/common/array_serialization.py @@ -418,7 +418,8 @@ async def cb(index: array.Index, device: jax.Device): restricted_domain = t.domain.intersect(requested_domain) requested_bytes = serialization.estimate_read_memory_footprint(t, restricted_domain) # Limit the bytes read for every shard. - await byte_limiter.wait_for_bytes(requested_bytes) + if os.getenv("JAX_PLATFORMS") != "proxy": + await byte_limiter.wait_for_bytes(requested_bytes) read_ts = t[restricted_domain] # Use ts.cast rather than np.astype since ts can perform casting on-the-fly. if dtype is not None: @@ -477,7 +478,8 @@ async def cb(index: array.Index, device: jax.Device): single_thread_pool, _blocking_device_put, out, layout ) - await byte_limiter.release_bytes(requested_bytes) + if os.getenv("JAX_PLATFORMS") != "proxy": + await byte_limiter.release_bytes(requested_bytes) return result return await serialization.create_async_array_from_callback(shape, in_sharding, cb) From 2ae8a5415dfbdd09a714d233cb8344b325a52d0b Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Fri, 12 Sep 2025 10:10:45 -0700 Subject: [PATCH 08/50] fuji pdbs=1 --- axlearn/experiments/text/gpt/fuji.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/axlearn/experiments/text/gpt/fuji.py b/axlearn/experiments/text/gpt/fuji.py index 9ec469dbb..08e5adda8 100644 --- a/axlearn/experiments/text/gpt/fuji.py +++ b/axlearn/experiments/text/gpt/fuji.py @@ -15,6 +15,7 @@ import itertools from typing import Any, List, NamedTuple, Optional, Union +import jax from jax.ad_checkpoint import checkpoint_policies as jax_remat_policies from axlearn.common import causal_lm, config @@ -841,6 +842,8 @@ def get_trainer_kwargs( ) else: raise NotImplementedError(f"Unknown model size {model_size}.") + total_chips = len(jax.devices()) + trainer_kwargs["train_batch_size"] = total_chips model_kwargs = trainer_kwargs.pop("model_kwargs") model_kwargs.setdefault("vocab_size", vocab_size) if version == Version.V3_TIKTOKEN: # tiktoken tokenizer From 6d84ab09f39ed4c5dcd8fbf2d74c53ad9860dc0a Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Fri, 12 Sep 2025 10:25:25 -0700 Subject: [PATCH 09/50] add logging of device puts --- axlearn/common/array_serialization.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/axlearn/common/array_serialization.py b/axlearn/common/array_serialization.py index 75f3e7b12..4c1b2827e 100644 --- a/axlearn/common/array_serialization.py +++ b/axlearn/common/array_serialization.py @@ -454,12 +454,21 @@ async def cb(index: array.Index, device: jax.Device): dll, jax.sharding.SingleDeviceSharding(device, memory_kind=in_sharding.memory_kind) ) try: + log_id = id(out) + logging.info( + "Sending jax.device_put of size %s MiB. Shape: %s. 
ID: %s", + out_size / (1024 * 1024), + out.shape, + log_id, + ) + start_time = time.time() if os.getenv("JAX_PLATFORMS") == "proxy": result = await loop.run_in_executor(None, _blocking_device_put, out, layout) else: await h2d_limiter.wait_for_bytes(out_size) result = await loop.run_in_executor(None, _blocking_device_put, out, layout) await h2d_limiter.release_bytes(out_size) + logging.info("Device put took %.4f seconds. ID: %s", time.time() - start_time, log_id) except ValueError as e: if "Requested more bytes than we reserved" not in str(e): raise e # Raise if it's not the type of error we expect. From cf282c6d796de5cc80913c74d4bd94895845f457 Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Fri, 12 Sep 2025 10:30:16 -0700 Subject: [PATCH 10/50] add profiling of checkpoint load --- axlearn/common/array_serialization.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/axlearn/common/array_serialization.py b/axlearn/common/array_serialization.py index 4c1b2827e..231b86f34 100644 --- a/axlearn/common/array_serialization.py +++ b/axlearn/common/array_serialization.py @@ -602,6 +602,7 @@ def deserialize( concurrent_gb: int = 32, ): self.wait_until_finished() + jax.profiler.start_trace("gs://cloud-tpu-multipod-dev-uss1/stoelinga-profile-1/") concurrent_bytes = concurrent_gb * 10**9 @@ -626,7 +627,9 @@ async def _run_deserializer(): return await asyncio.gather(*future_arrays) fut = asyncio.run_coroutine_threadsafe(_run_deserializer(), self._loop) - return fut.result() + result = fut.result() + jax.profiler.stop_trace() + return result class BoundedDataShardedAsyncCheckpointManager(GlobalAsyncCheckpointManager): From 36351ed197880d502cff1dd65a6b20d7912f7e44 Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Fri, 12 Sep 2025 10:42:22 -0700 Subject: [PATCH 11/50] Update fuji 70b mesh for v5e --- axlearn/experiments/text/gpt/fuji.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/axlearn/experiments/text/gpt/fuji.py b/axlearn/experiments/text/gpt/fuji.py index 08e5adda8..62ac78ede 100644 --- a/axlearn/experiments/text/gpt/fuji.py +++ b/axlearn/experiments/text/gpt/fuji.py @@ -705,6 +705,31 @@ def get_trainer_kwargs( ], ), ), + ( + "tpu-v5e-.*", + ChainConfigModifier.default_config().set( + config_modifiers=[ + MeshShapeModifier.default_config().set( + mesh_shape=mesh_shape_from_axes(fsdp=-1) + ), + RematSpecModifier.default_config().set( + remat_policies={ + "model.decoder.transformer.layer": RematSpec( + prevent_cse=False, + policy=config_for_function( + save_and_offload_only_these_names_regex + ).set( + names_which_can_be_saved=None, + names_which_can_be_offloaded=None, + offload_src="device", + offload_dst="pinned_host", + ), + ), + } + ), + ], + ), + ), ( "tpu-v5p-.*", ChainConfigModifier.default_config().set( From 793fb65ffae1e93ff414ec1335c38163be015313 Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Fri, 12 Sep 2025 10:57:46 -0700 Subject: [PATCH 12/50] bring back the limiters --- axlearn/common/array_serialization.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/axlearn/common/array_serialization.py b/axlearn/common/array_serialization.py index 231b86f34..97b06b8b9 100644 --- a/axlearn/common/array_serialization.py +++ b/axlearn/common/array_serialization.py @@ -418,8 +418,7 @@ async def cb(index: array.Index, device: jax.Device): restricted_domain = t.domain.intersect(requested_domain) requested_bytes = serialization.estimate_read_memory_footprint(t, restricted_domain) # Limit the bytes read for every 
shard. - if os.getenv("JAX_PLATFORMS") != "proxy": - await byte_limiter.wait_for_bytes(requested_bytes) + await byte_limiter.wait_for_bytes(requested_bytes) read_ts = t[restricted_domain] # Use ts.cast rather than np.astype since ts can perform casting on-the-fly. if dtype is not None: @@ -462,12 +461,9 @@ async def cb(index: array.Index, device: jax.Device): log_id, ) start_time = time.time() - if os.getenv("JAX_PLATFORMS") == "proxy": - result = await loop.run_in_executor(None, _blocking_device_put, out, layout) - else: - await h2d_limiter.wait_for_bytes(out_size) - result = await loop.run_in_executor(None, _blocking_device_put, out, layout) - await h2d_limiter.release_bytes(out_size) + await h2d_limiter.wait_for_bytes(out_size) + result = await loop.run_in_executor(None, _blocking_device_put, out, layout) + await h2d_limiter.release_bytes(out_size) logging.info("Device put took %.4f seconds. ID: %s", time.time() - start_time, log_id) except ValueError as e: if "Requested more bytes than we reserved" not in str(e): @@ -487,8 +483,7 @@ async def cb(index: array.Index, device: jax.Device): single_thread_pool, _blocking_device_put, out, layout ) - if os.getenv("JAX_PLATFORMS") != "proxy": - await byte_limiter.release_bytes(requested_bytes) + await byte_limiter.release_bytes(requested_bytes) return result return await serialization.create_async_array_from_callback(shape, in_sharding, cb) From 0ab519fdf55534fa24beeda0664c00763e18400c Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Fri, 12 Sep 2025 11:56:50 -0700 Subject: [PATCH 13/50] add log to print total restore time --- axlearn/common/array_serialization.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/axlearn/common/array_serialization.py b/axlearn/common/array_serialization.py index 97b06b8b9..c914ae722 100644 --- a/axlearn/common/array_serialization.py +++ b/axlearn/common/array_serialization.py @@ -597,6 +597,7 @@ def deserialize( concurrent_gb: int = 32, ): self.wait_until_finished() + start_time = time.time() jax.profiler.start_trace("gs://cloud-tpu-multipod-dev-uss1/stoelinga-profile-1/") concurrent_bytes = concurrent_gb * 10**9 @@ -624,6 +625,7 @@ async def _run_deserializer(): fut = asyncio.run_coroutine_threadsafe(_run_deserializer(), self._loop) result = fut.result() jax.profiler.stop_trace() + logging.info("deserialize took %.4f seconds.", time.time() - start_time) return result From eea2db5eaf1c825f2a44dc9ec3234cb5c19a2d06 Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Fri, 12 Sep 2025 12:05:17 -0700 Subject: [PATCH 14/50] try deleting from HBM after device_put --- axlearn/common/array_serialization.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/axlearn/common/array_serialization.py b/axlearn/common/array_serialization.py index c914ae722..abafdd530 100644 --- a/axlearn/common/array_serialization.py +++ b/axlearn/common/array_serialization.py @@ -465,6 +465,8 @@ async def cb(index: array.Index, device: jax.Device): result = await loop.run_in_executor(None, _blocking_device_put, out, layout) await h2d_limiter.release_bytes(out_size) logging.info("Device put took %.4f seconds. ID: %s", time.time() - start_time, log_id) + # We delete afterwards from HBM since we're testing on v5e with limited HBM + result.delete() except ValueError as e: if "Requested more bytes than we reserved" not in str(e): raise e # Raise if it's not the type of error we expect. 
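The result.delete() call added in PATCH 14 above is easier to see in isolation. This is a minimal sketch under assumed inputs (a single local device and an arbitrary ~1 GiB host buffer; none of these names come from axlearn): it copies the buffer to HBM, waits for the transfer, then releases the device memory explicitly instead of waiting for garbage collection, which is the idea being tried here before the next patch reverts it.

import time

import jax
import numpy as np

host_buf = np.zeros((256, 1024, 1024), dtype=np.float32)  # ~1 GiB of host memory
device = jax.devices()[0]

start = time.time()
on_device = jax.device_put(host_buf, device)
on_device.block_until_ready()
print(f"device_put took {time.time() - start:.4f} seconds")

# Free the HBM copy immediately; accessing `on_device` after this point raises an error.
on_device.delete()
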
From 0894e2d7afe02908a5a60febea114b553c6f4766 Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Fri, 12 Sep 2025 12:25:36 -0700 Subject: [PATCH 15/50] save every 100 steps fuji 7b and remove delete --- axlearn/common/array_serialization.py | 2 +- axlearn/experiments/text/gpt/fuji.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/axlearn/common/array_serialization.py b/axlearn/common/array_serialization.py index abafdd530..6646f9f58 100644 --- a/axlearn/common/array_serialization.py +++ b/axlearn/common/array_serialization.py @@ -466,7 +466,7 @@ async def cb(index: array.Index, device: jax.Device): await h2d_limiter.release_bytes(out_size) logging.info("Device put took %.4f seconds. ID: %s", time.time() - start_time, log_id) # We delete afterwards from HBM since we're testing on v5e with limited HBM - result.delete() + # result.delete(), this didn't work it causes instance device_puts except ValueError as e: if "Requested more bytes than we reserved" not in str(e): raise e # Raise if it's not the type of error we expect. diff --git a/axlearn/experiments/text/gpt/fuji.py b/axlearn/experiments/text/gpt/fuji.py index 62ac78ede..775156646 100644 --- a/axlearn/experiments/text/gpt/fuji.py +++ b/axlearn/experiments/text/gpt/fuji.py @@ -410,6 +410,7 @@ def get_trainer_kwargs( learner_kwargs=dict(peak_lr=3e-4, weight_decay=0.1), max_sequence_length=max_sequence_length, train_batch_size=train_batch_size, + save_every_n_steps=100, max_step=max_step, mesh_shape=mesh_shape_from_axes(data=-1, fsdp=8), mesh_rules=( From ddb6b03f677ebd27110d26c9420539169440bf11 Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Fri, 12 Sep 2025 12:30:28 -0700 Subject: [PATCH 16/50] dont save any remats fuji 7b --- axlearn/experiments/text/gpt/fuji.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/axlearn/experiments/text/gpt/fuji.py b/axlearn/experiments/text/gpt/fuji.py index 775156646..fc7f107ef 100644 --- a/axlearn/experiments/text/gpt/fuji.py +++ b/axlearn/experiments/text/gpt/fuji.py @@ -425,7 +425,7 @@ def get_trainer_kwargs( ("tpu-v4-(1024|2048)", mesh_shape_from_axes(data=-1, fsdp=16)), # tpu-v5e. 
( - "tpu-v5litepod-256", + "tpu-v5litepod-32-1", ChainConfigModifier.default_config().set( config_modifiers=[ MeshShapeModifier.default_config().set( @@ -435,11 +435,18 @@ def get_trainer_kwargs( remat_policies={ "model.decoder.transformer.layer": RematSpec( prevent_cse=False, - policy=offload_dots_saveable_policy, + policy=config_for_function( + save_and_offload_only_these_names_regex + ).set( + names_which_can_be_saved=None, + names_which_can_be_offloaded=None, + offload_src="device", + offload_dst="pinned_host", + ), ), } ), - GradientAccumulationModifier.default_config().set(grad_acc_steps=4), + # GradientAccumulationModifier.default_config().set(grad_acc_steps=4), ], ), ), From 1f87664604537c8192000be019b3ffe961c4e763 Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Fri, 12 Sep 2025 12:32:26 -0700 Subject: [PATCH 17/50] 7b fsdp=32 --- axlearn/experiments/text/gpt/fuji.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/axlearn/experiments/text/gpt/fuji.py b/axlearn/experiments/text/gpt/fuji.py index fc7f107ef..3373f75f4 100644 --- a/axlearn/experiments/text/gpt/fuji.py +++ b/axlearn/experiments/text/gpt/fuji.py @@ -429,7 +429,7 @@ def get_trainer_kwargs( ChainConfigModifier.default_config().set( config_modifiers=[ MeshShapeModifier.default_config().set( - mesh_shape=mesh_shape_from_axes(data=-1, fsdp=256) + mesh_shape=mesh_shape_from_axes(fsdp=32) ), RematSpecModifier.default_config().set( remat_policies={ From 68bf1221f0f06db3f74074d69e311479f37efec7 Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Fri, 12 Sep 2025 13:11:10 -0700 Subject: [PATCH 18/50] exit after deserialize --- axlearn/common/array_serialization.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/axlearn/common/array_serialization.py b/axlearn/common/array_serialization.py index 6646f9f58..86cbf3511 100644 --- a/axlearn/common/array_serialization.py +++ b/axlearn/common/array_serialization.py @@ -20,6 +20,7 @@ import functools import math import os +import sys import threading import time from collections import defaultdict @@ -628,6 +629,8 @@ async def _run_deserializer(): result = fut.result() jax.profiler.stop_trace() logging.info("deserialize took %.4f seconds.", time.time() - start_time) + sys.exit(0) + # pylint: disable=unreachable return result From e30c7ce160818ec2c4e514c18bd9f0e7cc057fe9 Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Fri, 12 Sep 2025 13:23:42 -0700 Subject: [PATCH 19/50] concurrent restore 128GB --- axlearn/common/array_serialization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/axlearn/common/array_serialization.py b/axlearn/common/array_serialization.py index 86cbf3511..a485eb8ff 100644 --- a/axlearn/common/array_serialization.py +++ b/axlearn/common/array_serialization.py @@ -597,7 +597,7 @@ def deserialize( tensorstore_specs: Sequence[dict[str, Any]], global_shapes: Optional[Sequence[array.Shape]] = None, dtypes: Optional[Sequence[typing.DTypeLike]] = None, - concurrent_gb: int = 32, + concurrent_gb: int = 128, ): self.wait_until_finished() start_time = time.time() From 213fa24cac9c5331d506f302c944f4a00b69e2e7 Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Fri, 12 Sep 2025 13:28:43 -0700 Subject: [PATCH 20/50] print total time before stopping trace --- axlearn/common/array_serialization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/axlearn/common/array_serialization.py b/axlearn/common/array_serialization.py index a485eb8ff..e77bbe252 100644 --- a/axlearn/common/array_serialization.py +++ 
b/axlearn/common/array_serialization.py @@ -627,8 +627,8 @@ async def _run_deserializer(): fut = asyncio.run_coroutine_threadsafe(_run_deserializer(), self._loop) result = fut.result() - jax.profiler.stop_trace() logging.info("deserialize took %.4f seconds.", time.time() - start_time) + jax.profiler.stop_trace() sys.exit(0) # pylint: disable=unreachable return result From d0e45e62d66941fabd559533782a2c9b85cb9d2d Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Fri, 12 Sep 2025 13:32:15 -0700 Subject: [PATCH 21/50] add logging of concurrent_gb --- axlearn/common/array_serialization.py | 1 + 1 file changed, 1 insertion(+) diff --git a/axlearn/common/array_serialization.py b/axlearn/common/array_serialization.py index e77bbe252..480313eb3 100644 --- a/axlearn/common/array_serialization.py +++ b/axlearn/common/array_serialization.py @@ -599,6 +599,7 @@ def deserialize( dtypes: Optional[Sequence[typing.DTypeLike]] = None, concurrent_gb: int = 128, ): + logging.info("concurrent_gb=%s GB.", concurrent_gb) self.wait_until_finished() start_time = time.time() jax.profiler.start_trace("gs://cloud-tpu-multipod-dev-uss1/stoelinga-profile-1/") From 4fc79a1536542ad5e51e56c9c64599d212712696 Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Fri, 12 Sep 2025 13:35:32 -0700 Subject: [PATCH 22/50] force to 128gb for real this time --- axlearn/common/array_serialization.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/axlearn/common/array_serialization.py b/axlearn/common/array_serialization.py index 480313eb3..da1e87c04 100644 --- a/axlearn/common/array_serialization.py +++ b/axlearn/common/array_serialization.py @@ -599,6 +599,8 @@ def deserialize( dtypes: Optional[Sequence[typing.DTypeLike]] = None, concurrent_gb: int = 128, ): + # force to 128 + concurrent_gb = max(128, concurrent_gb) logging.info("concurrent_gb=%s GB.", concurrent_gb) self.wait_until_finished() start_time = time.time() From d71ce4adc70b42bdbbe8072387fbef920c5dce39 Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Fri, 12 Sep 2025 13:39:20 -0700 Subject: [PATCH 23/50] add improvements to GCS perf --- axlearn/common/array_serialization.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/axlearn/common/array_serialization.py b/axlearn/common/array_serialization.py index da1e87c04..33c50120f 100644 --- a/axlearn/common/array_serialization.py +++ b/axlearn/common/array_serialization.py @@ -404,11 +404,24 @@ async def _async_deserialize( f" an instance of `jax.sharding.Sharding`. Got {in_sharding}" ) dll = user_in_sharding.device_local_layout if isinstance(user_in_sharding, Layout) else None + + # gcs_grpc improves performance. 
+ if tensorstore_spec.get("kvstore", {}).get("driver", "") == "gcs": + tensorstore_spec["kvstore"]["driver"] = "gcs_grpc" + t = await ts.open( tensorstore_spec, open=True, assume_metadata=False, - context=serialization.TS_CONTEXT, + # context=serialization.TS_CONTEXT, + # Improve GCS performance + context=ts.Context( + { + "cache_pool": {"total_bytes_limit": 0}, + "data_copy_concurrency": {"limit": "shared"}, + "gcs_request_concurrency": {"limit": 480}, + } + ), ) shape = tuple(t.shape if global_shape is None else global_shape) new_shard_shape = in_sharding.shard_shape(shape) From 22b8fafd9dac31650ce44d6b45f169f6fe34ceef Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Fri, 12 Sep 2025 14:10:51 -0700 Subject: [PATCH 24/50] non blocking device put, only block when all device puts are done --- axlearn/common/array_serialization.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/axlearn/common/array_serialization.py b/axlearn/common/array_serialization.py index 33c50120f..024ea83c9 100644 --- a/axlearn/common/array_serialization.py +++ b/axlearn/common/array_serialization.py @@ -350,7 +350,9 @@ async def _run_serializer( def _blocking_device_put(out: Tensor, layout: Layout) -> Tensor: - return jax.block_until_ready(jax.device_put(out, layout)) + # Make it non blocking + # return jax.block_until_ready(jax.device_put(out, layout)) + return jax.device_put(out, layout) async def _async_deserialize( @@ -409,6 +411,8 @@ async def _async_deserialize( if tensorstore_spec.get("kvstore", {}).get("driver", "") == "gcs": tensorstore_spec["kvstore"]["driver"] = "gcs_grpc" + logging.info("tensorstore_spec: %s", tensorstore_spec) + t = await ts.open( tensorstore_spec, open=True, @@ -643,6 +647,7 @@ async def _run_deserializer(): fut = asyncio.run_coroutine_threadsafe(_run_deserializer(), self._loop) result = fut.result() + jax.block_until_ready(result) logging.info("deserialize took %.4f seconds.", time.time() - start_time) jax.profiler.stop_trace() sys.exit(0) From a129ef3807b370f4e35a2fad11228d5a2a60d29b Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Fri, 12 Sep 2025 14:20:32 -0700 Subject: [PATCH 25/50] time the download from GCS --- axlearn/common/array_serialization.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/axlearn/common/array_serialization.py b/axlearn/common/array_serialization.py index 024ea83c9..f07df8ffe 100644 --- a/axlearn/common/array_serialization.py +++ b/axlearn/common/array_serialization.py @@ -451,9 +451,11 @@ async def cb(index: array.Index, device: jax.Device): # the extra values will be filled with 0s. out = np.zeros(new_shard_shape, read_ts.dtype.numpy_dtype) + write_start_time = time.time() await ts.array(out)[ts.d[:].translate_to[requested_domain.origin]][restricted_domain].write( read_ts ) + logging.info("ts.array.write took %.4f seconds.", time.time() - write_start_time) # Convert to jnp array so that layouts are initialized properly for # sub-byte dtypes. 
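The timing added in PATCH 25 above wraps the TensorStore copy that actually pulls shard bytes out of GCS. Below is a stripped-down sketch of the same measurement on a whole-array read; it is illustrative only, and `spec` is a placeholder TensorStore spec (for example the one produced by get_tensorstore_spec in benchmark_deserialize.py), not a real checkpoint path.

import asyncio
import time

import tensorstore as ts


async def timed_read(spec: dict):
    store = await ts.open(spec, open=True)
    start = time.time()
    data = await store.read()  # copies the bytes from GCS into host memory
    print(
        f"read took {time.time() - start:.4f} seconds "
        f"({data.nbytes / 2**20:.1f} MiB)"
    )
    return data

# Example usage (hypothetical path):
# asyncio.run(timed_read(get_tensorstore_spec("gs://<bucket>/<ckpt>/gda/<tensor>")))
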
From 73654f2fd399cfb1149b7356d741b49c3a1fed1a Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Fri, 12 Sep 2025 14:25:03 -0700 Subject: [PATCH 26/50] add scripts to launch fuji within interactive pathways cluster --- restore_fuji_70b.sh | 8 ++++++++ train_fuji_7b.sh | 7 +++++++ 2 files changed, 15 insertions(+) create mode 100644 restore_fuji_70b.sh create mode 100644 train_fuji_7b.sh diff --git a/restore_fuji_70b.sh b/restore_fuji_70b.sh new file mode 100644 index 000000000..08bd5b352 --- /dev/null +++ b/restore_fuji_70b.sh @@ -0,0 +1,8 @@ +python3 -m axlearn.common.launch_trainer_main \ + --module=text.gpt.c4_trainer \ + --config=fuji-70B-v3-flash \ + --trainer_dir=gs://cloud-tpu-multipod-dev-uss1/axlearn-fuji-v3-70b/ \ + --data_dir=gs://axlearn-public/tensorflow_datasets \ + --jax_backend=proxy \ + --mesh_selector=tpu-v5litepod-32-1 \ + --trace_at_steps=11 diff --git a/train_fuji_7b.sh b/train_fuji_7b.sh new file mode 100644 index 000000000..68c1f2d65 --- /dev/null +++ b/train_fuji_7b.sh @@ -0,0 +1,7 @@ +python3 -m axlearn.common.launch_trainer_main \ + --module=text.gpt.c4_trainer \ + --config=fuji-7B-v2-flash \ + --trainer_dir=gs://cloud-tpu-multipod-dev-uss1/axlearn-fuji-v2-7b/ \ + --data_dir=gs://axlearn-public/tensorflow_datasets \ + --jax_backend=proxy \ + --mesh_selector=tpu-v5litepod-32-1 From 2033d655e5008be0931006399e8075ce3982f05c Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Fri, 12 Sep 2025 15:19:29 -0700 Subject: [PATCH 27/50] add pathways premap buffer 17gb --- axlearn/cloud/gcp/pathways_utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/axlearn/cloud/gcp/pathways_utils.py b/axlearn/cloud/gcp/pathways_utils.py index 986d5ecaf..12b49ce95 100644 --- a/axlearn/cloud/gcp/pathways_utils.py +++ b/axlearn/cloud/gcp/pathways_utils.py @@ -537,6 +537,9 @@ def _build_pathways_worker_container( f"--resource_manager_address={pathways_head_address}:" + f"{_PATHWAYS_RESOURCE_MANAGER_PORT}", f"--gcs_scratch_location={cfg.output_dir}/pathways-staging", + # Set premap buffer to 17GB, needed for faster jax.device_put h2d + # pylint: disable=line-too-long + "--temporary_flags_for_debugging=temporary_flag_for_debugging_tpu_premapped_buffer_size=17179869184", ] mega_scale_args = xla_flags_from_options(self._mxla_options).split() worker_container["args"].extend(mega_scale_args) From 67127b30e784fa6b946d7a57c4cb0b3f4c8a4d01 Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Fri, 12 Sep 2025 15:24:44 -0700 Subject: [PATCH 28/50] Generate unique subdir for profile --- axlearn/common/array_serialization.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/axlearn/common/array_serialization.py b/axlearn/common/array_serialization.py index f07df8ffe..516b02b30 100644 --- a/axlearn/common/array_serialization.py +++ b/axlearn/common/array_serialization.py @@ -23,6 +23,7 @@ import sys import threading import time +import uuid from collections import defaultdict from concurrent import futures from concurrent.futures import ThreadPoolExecutor @@ -623,7 +624,8 @@ def deserialize( logging.info("concurrent_gb=%s GB.", concurrent_gb) self.wait_until_finished() start_time = time.time() - jax.profiler.start_trace("gs://cloud-tpu-multipod-dev-uss1/stoelinga-profile-1/") + uid = uuid.uuid4() + jax.profiler.start_trace(f"gs://cloud-tpu-multipod-dev-uss1/stoelinga-{uid}/") concurrent_bytes = concurrent_gb * 10**9 From 976dcd374d4858e674721b11fceb7721c97e0b57 Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Fri, 12 Sep 2025 15:41:48 -0700 Subject: [PATCH 29/50] pathways bump 
async computations --- axlearn/cloud/gcp/pathways_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/axlearn/cloud/gcp/pathways_utils.py b/axlearn/cloud/gcp/pathways_utils.py index 12b49ce95..fd14a9054 100644 --- a/axlearn/cloud/gcp/pathways_utils.py +++ b/axlearn/cloud/gcp/pathways_utils.py @@ -540,6 +540,7 @@ def _build_pathways_worker_container( # Set premap buffer to 17GB, needed for faster jax.device_put h2d # pylint: disable=line-too-long "--temporary_flags_for_debugging=temporary_flag_for_debugging_tpu_premapped_buffer_size=17179869184", + "--temporary_flags_for_debugging=temporary_flag_for_debugging_pathways_xla_max_inflight_async_computations=1000", ] mega_scale_args = xla_flags_from_options(self._mxla_options).split() worker_container["args"].extend(mega_scale_args) From 0c5f5071990ae765956e1bd478b6cce27b861be2 Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Fri, 12 Sep 2025 15:51:35 -0700 Subject: [PATCH 30/50] comment out flag that didnt work --- axlearn/cloud/gcp/pathways_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/axlearn/cloud/gcp/pathways_utils.py b/axlearn/cloud/gcp/pathways_utils.py index fd14a9054..a3ab2b01c 100644 --- a/axlearn/cloud/gcp/pathways_utils.py +++ b/axlearn/cloud/gcp/pathways_utils.py @@ -540,7 +540,8 @@ def _build_pathways_worker_container( # Set premap buffer to 17GB, needed for faster jax.device_put h2d # pylint: disable=line-too-long "--temporary_flags_for_debugging=temporary_flag_for_debugging_tpu_premapped_buffer_size=17179869184", - "--temporary_flags_for_debugging=temporary_flag_for_debugging_pathways_xla_max_inflight_async_computations=1000", + # Causes crash on cloud + # "--temporary_flags_for_debugging=temporary_flag_for_debugging_pathways_xla_max_inflight_async_computations=1000", ] mega_scale_args = xla_flags_from_options(self._mxla_options).split() worker_container["args"].extend(mega_scale_args) From 0c0cd775ed24aa16c370171a5ffe5c020c286ac1 Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Fri, 12 Sep 2025 16:07:04 -0700 Subject: [PATCH 31/50] proper pathways flag prefix --- axlearn/cloud/gcp/pathways_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/axlearn/cloud/gcp/pathways_utils.py b/axlearn/cloud/gcp/pathways_utils.py index a3ab2b01c..b92b43bab 100644 --- a/axlearn/cloud/gcp/pathways_utils.py +++ b/axlearn/cloud/gcp/pathways_utils.py @@ -541,7 +541,7 @@ def _build_pathways_worker_container( # pylint: disable=line-too-long "--temporary_flags_for_debugging=temporary_flag_for_debugging_tpu_premapped_buffer_size=17179869184", # Causes crash on cloud - # "--temporary_flags_for_debugging=temporary_flag_for_debugging_pathways_xla_max_inflight_async_computations=1000", + "--temporary_flags_for_debugging=temporary_flag_for_debugging_xla_max_inflight_async_computations=1000", ] mega_scale_args = xla_flags_from_options(self._mxla_options).split() worker_container["args"].extend(mega_scale_args) From a98c2a2ae072193fcbe69aa88ab7db78bb935c4e Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Fri, 12 Sep 2025 16:42:36 -0700 Subject: [PATCH 32/50] test pathways flags --- axlearn/cloud/gcp/pathways_utils.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/axlearn/cloud/gcp/pathways_utils.py b/axlearn/cloud/gcp/pathways_utils.py index b92b43bab..9e25bfebe 100644 --- a/axlearn/cloud/gcp/pathways_utils.py +++ b/axlearn/cloud/gcp/pathways_utils.py @@ -539,9 +539,12 @@ def _build_pathways_worker_container( 
f"--gcs_scratch_location={cfg.output_dir}/pathways-staging", # Set premap buffer to 17GB, needed for faster jax.device_put h2d # pylint: disable=line-too-long - "--temporary_flags_for_debugging=temporary_flag_for_debugging_tpu_premapped_buffer_size=17179869184", - # Causes crash on cloud - "--temporary_flags_for_debugging=temporary_flag_for_debugging_xla_max_inflight_async_computations=1000", + # Below flags did not help on 7b restore time + # "--temporary_flags_for_debugging=temporary_flag_for_debugging_tpu_premapped_buffer_size=68719476736", + # "--temporary_flags_for_debugging=temporary_flag_for_debugging_xla_max_inflight_async_computations=1000", + # "--temporary_flags_for_debugging=temporary_flag_for_debugging_tpu_pinned_host_allocation_mode=recycle", + # "--temporary_flags_for_debugging=temporary_flag_for_debugging_xla_tpu_allow_async_allocations=true", + # "--temporary_flags_for_debugging=temporary_flag_for_debugging_tpu_num_premapped_partitions=16", ] mega_scale_args = xla_flags_from_options(self._mxla_options).split() worker_container["args"].extend(mega_scale_args) From f8adbccfd28fd8d07e0d86b5fd606ab987eefa58 Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Fri, 12 Sep 2025 22:13:39 -0700 Subject: [PATCH 33/50] make the proxy bench more similar to axlearn --- proxy_bench.py | 64 +++++++++++++++++++++++++------------------------- 1 file changed, 32 insertions(+), 32 deletions(-) diff --git a/proxy_bench.py b/proxy_bench.py index ca5312d34..4dc10dee4 100644 --- a/proxy_bench.py +++ b/proxy_bench.py @@ -1,51 +1,37 @@ """A script to benchmark JAX device_put throughput.""" +import asyncio import os import time import jax import numpy as np import pathwaysutils -from jax.sharding import Mesh, NamedSharding, PartitionSpec +from jax.sharding import SingleDeviceSharding -def benchmark_host_to_device_throughput(device_put_buffer_mb: int = 1024): +async def benchmark_host_to_device_throughput( + device_put_buffer_mb: int = 512, num_transfers: int = 10 +): """Benchmarks JAX device_put throughput from CPU host to a v5e-32 TPU slice.""" print(f"JAX version: {jax.__version__}") devices = jax.devices() if os.environ.get("JAX_PLATFORMS") else jax.local_devices() num_devices = len(devices) print(f"Available devices: {num_devices}") - data_bytes_per_device = int(device_put_buffer_mb * 1024 * 1024) # 1 GiB + data_bytes_per_device = int(device_put_buffer_mb * 1024 * 1024) dtype = np.float32 num_elements = data_bytes_per_device // np.dtype(dtype).itemsize data_gb_per_device = num_elements * np.dtype(dtype).itemsize / (1024**3) print( - f"Creating a NumPy array of shape ({num_elements},) type {dtype}, size" - f" {data_gb_per_device:.2f} GiB" + f"Creating {num_devices} NumPy arrays of shape ({num_elements},) type {dtype}, size" + f" {data_gb_per_device:.2f} GiB each" ) - host_array = np.arange(num_elements, dtype=dtype) - - # Create a mesh spanning all devices. - mesh = Mesh(np.array(devices), axis_names=("i",)) - # An empty PartitionSpec() means the array is fully replicated across all - # devices in the mesh. 
- replicated_sharding = NamedSharding(mesh, PartitionSpec()) - print(f"Using sharding for replication: {replicated_sharding}") - - # Warm-up transfer - print("Performing warm-up transfer...") - try: - dummy_array = jax.device_put(host_array, replicated_sharding) - dummy_array.block_until_ready() - print("Warm-up complete.") - except RuntimeError as e: - print(f"Error during warm-up: {e}") - return - - # Benchmark loop - num_transfers = 5 + host_arrays = [np.arange(num_elements, dtype=dtype) for _ in range(num_devices)] + shardings = [SingleDeviceSharding(device) for device in devices] + + loop = asyncio.get_running_loop() transfer_times = [] print(f"Starting benchmark ({num_transfers} transfers)...") @@ -53,26 +39,40 @@ def benchmark_host_to_device_throughput(device_put_buffer_mb: int = 1024): if i == 0: trace_dir = "gs://cloud-tpu-multipod-dev-axlearn/stoelinga-proxy-benchmark" jax.profiler.start_trace(f"{trace_dir}/{device_put_buffer_mb}mb") + start_time = time.perf_counter() - device_array = jax.device_put(host_array, replicated_sharding) - device_array.block_until_ready() + + # Issue device_put calls in parallel. + device_put_futures = [ + loop.run_in_executor(None, jax.device_put, host_arrays[j], shardings[j]) + for j in range(num_devices) + ] + device_arrays = await asyncio.gather(*device_put_futures) + + # Block until all transfers are complete. + for device_array in device_arrays: + device_array.block_until_ready() + end_time = time.perf_counter() duration = end_time - start_time transfer_times.append(duration) print(f"Transfer {i+1}/{num_transfers}: {duration:.4f} seconds") + if i == 0: jax.profiler.stop_trace() - del device_array # Optional: hint for early deletion + + # Optional: hint for early deletion. + del device_arrays avg_time = np.mean(transfer_times) - print(f"\nAverage time per device_put call: {avg_time:.4f} seconds") + print(f"\nAverage time per parallel device_put batch: {avg_time:.4f} seconds") total_data_moved_gb = data_gb_per_device * num_devices throughput_gb_s = total_data_moved_gb / avg_time print(f"Data per device: {data_gb_per_device:.2f} GiB") - print("Total data transferred from host per operation:" f" {total_data_moved_gb:.2f} GiB") + print(f"Total data transferred from host per operation: {total_data_moved_gb:.2f} GiB") print(f"Aggregated Host -> Devices Throughput: {throughput_gb_s:.2f} GiB/s") print(f"Aggregated Host -> Devices Throughput: {throughput_gb_s * 8:.2f} Gbps/s") @@ -85,4 +85,4 @@ def benchmark_host_to_device_throughput(device_put_buffer_mb: int = 1024): scenarios_mb = [1, 128, 1024, 2048] for scenario in scenarios_mb: print(f"Running scenario {scenario}MB") - benchmark_host_to_device_throughput(scenario) + asyncio.run(benchmark_host_to_device_throughput(scenario)) From d1c38bcee5600bc3725864c159ee677e632865e4 Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Sat, 13 Sep 2025 14:24:26 -0700 Subject: [PATCH 34/50] rerun with premap buffer set --- axlearn/cloud/gcp/pathways_utils.py | 5 +++-- axlearn/common/array_serialization.py | 1 + proxy_bench.py | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/axlearn/cloud/gcp/pathways_utils.py b/axlearn/cloud/gcp/pathways_utils.py index 9e25bfebe..74b13b472 100644 --- a/axlearn/cloud/gcp/pathways_utils.py +++ b/axlearn/cloud/gcp/pathways_utils.py @@ -538,9 +538,10 @@ def _build_pathways_worker_container( + f"{_PATHWAYS_RESOURCE_MANAGER_PORT}", f"--gcs_scratch_location={cfg.output_dir}/pathways-staging", # Set premap buffer to 17GB, needed for faster jax.device_put h2d - # pylint: 
disable=line-too-long + # "--pathways_tpu_premapped_buffer_size=17179869184" doesn't work in cloud # Below flags did not help on 7b restore time - # "--temporary_flags_for_debugging=temporary_flag_for_debugging_tpu_premapped_buffer_size=68719476736", + # pylint: disable=line-too-long + "--temporary_flags_for_debugging=temporary_flag_for_debugging_tpu_premapped_buffer_size=68719476736", # "--temporary_flags_for_debugging=temporary_flag_for_debugging_xla_max_inflight_async_computations=1000", # "--temporary_flags_for_debugging=temporary_flag_for_debugging_tpu_pinned_host_allocation_mode=recycle", # "--temporary_flags_for_debugging=temporary_flag_for_debugging_xla_tpu_allow_async_allocations=true", diff --git a/axlearn/common/array_serialization.py b/axlearn/common/array_serialization.py index 516b02b30..0c611b8b9 100644 --- a/axlearn/common/array_serialization.py +++ b/axlearn/common/array_serialization.py @@ -470,6 +470,7 @@ async def cb(index: array.Index, device: jax.Device): mb_256 = 256 * 1024 * 1024 out_size = math.ceil(out_size / mb_256) * mb_256 + logging.info("in_sharding: %s", in_sharding) layout = Layout( dll, jax.sharding.SingleDeviceSharding(device, memory_kind=in_sharding.memory_kind) ) diff --git a/proxy_bench.py b/proxy_bench.py index 4dc10dee4..a3c95b8f2 100644 --- a/proxy_bench.py +++ b/proxy_bench.py @@ -37,7 +37,7 @@ async def benchmark_host_to_device_throughput( print(f"Starting benchmark ({num_transfers} transfers)...") for i in range(num_transfers): if i == 0: - trace_dir = "gs://cloud-tpu-multipod-dev-axlearn/stoelinga-proxy-benchmark" + trace_dir = "gs://cloud-tpu-multipod-dev-axlearn/stoelinga-proxy-benchmark-premap" jax.profiler.start_trace(f"{trace_dir}/{device_put_buffer_mb}mb") start_time = time.perf_counter() From 4cd54d541e5cb97fec9ddc2ee8d4f7608a13a286 Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Sat, 13 Sep 2025 16:39:58 -0700 Subject: [PATCH 35/50] mess around with pathways flags --- axlearn/cloud/gcp/pathways_utils.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/axlearn/cloud/gcp/pathways_utils.py b/axlearn/cloud/gcp/pathways_utils.py index 74b13b472..e30247e11 100644 --- a/axlearn/cloud/gcp/pathways_utils.py +++ b/axlearn/cloud/gcp/pathways_utils.py @@ -540,12 +540,14 @@ def _build_pathways_worker_container( # Set premap buffer to 17GB, needed for faster jax.device_put h2d # "--pathways_tpu_premapped_buffer_size=17179869184" doesn't work in cloud # Below flags did not help on 7b restore time + # Recycle vs on-demand seems to give a slight perf boost + "--tpu_pinned_host_allocation_recycle=true", # pylint: disable=line-too-long - "--temporary_flags_for_debugging=temporary_flag_for_debugging_tpu_premapped_buffer_size=68719476736", + # "--temporary_flags_for_debugging=temporary_flag_for_debugging_tpu_premapped_buffer_size=68719476736", + # "--temporary_flags_for_debugging=temporary_flag_for_debuggings_max_num_threads_for_xla_compilation=1000" # "--temporary_flags_for_debugging=temporary_flag_for_debugging_xla_max_inflight_async_computations=1000", - # "--temporary_flags_for_debugging=temporary_flag_for_debugging_tpu_pinned_host_allocation_mode=recycle", # "--temporary_flags_for_debugging=temporary_flag_for_debugging_xla_tpu_allow_async_allocations=true", - # "--temporary_flags_for_debugging=temporary_flag_for_debugging_tpu_num_premapped_partitions=16", + # "--temporary_flags_for_debugging=temporary_flag_for_debugging_tpu_num_premapped_partitions=65536", ] mega_scale_args = 
xla_flags_from_options(self._mxla_options).split() worker_container["args"].extend(mega_scale_args) @@ -909,6 +911,14 @@ def _build_pathways_rm_container(self) -> dict: "--instance_count=1", f"--instance_type={pathways_tpu_version}:{system.topology}", f"--gcs_scratch_location={staging_location}", + # Troubleshooting perf + "--tpu_pinned_host_allocation_recycle=true", + # pylint: disable=line-too-long + # "--temporary_flags_for_debugging=temporary_flag_for_debugging_tpu_premapped_buffer_size=68719476736", + # "--temporary_flags_for_debugging=temporary_flag_for_debuggings_max_num_threads_for_xla_compilation=1000" + # "--temporary_flags_for_debugging=temporary_flag_for_debugging_xla_max_inflight_async_computations=1000", + # "--temporary_flags_for_debugging=temporary_flag_for_debugging_xla_tpu_allow_async_allocations=true", + # "--temporary_flags_for_debugging=temporary_flag_for_debugging_tpu_num_premapped_partitions=65536", ], ports=[dict(containerPort=_PATHWAYS_RESOURCE_MANAGER_PORT)], ) From 9c245b453aff840a33e61ede1091d53d8b4aff4d Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Sat, 13 Sep 2025 20:42:36 -0700 Subject: [PATCH 36/50] pathways head on TPU VM --- axlearn/cloud/gcp/pathways_utils.py | 74 +++++++++++++++++++++++------ 1 file changed, 59 insertions(+), 15 deletions(-) diff --git a/axlearn/cloud/gcp/pathways_utils.py b/axlearn/cloud/gcp/pathways_utils.py index e30247e11..446a2d367 100644 --- a/axlearn/cloud/gcp/pathways_utils.py +++ b/axlearn/cloud/gcp/pathways_utils.py @@ -146,12 +146,14 @@ class Config(BaseReplicatedJob.Config): inner: The wrapped TPUReplicatedJob configuration. pathways_head_cpu: CPU request for pathways-head container. pathways_head_mem: Memory request for pathways-head container. + pathways_head_on_tpu: Whether to run pathways head on TPU VM. """ inner: Required[TPUReplicatedJob.Config] = REQUIRED pathways_xla_flags: list[str] = [] pathways_head_cpu: Optional[str] = None pathways_head_mem: Optional[str] = None + pathways_head_on_tpu: bool = False @classmethod def define_flags(cls, fv): @@ -180,6 +182,12 @@ def define_flags(cls, fv): "Memory request for pathways-head container in GiB. 
Default is 16GiB", **common_kwargs, ) + flags.DEFINE_boolean( + "pathways_head_on_tpu", + False, + "If True, run pathways head on TPU VM.", + **common_kwargs, + ) @classmethod def set_defaults(cls, fv): @@ -414,9 +422,15 @@ def _build_pathways_head_pod(self) -> Nested[Any]: } ) - node_selector = { - _PATHWAYS_HEAD_NODE_POOL_SELECTOR_KEY: _PATHWAYS_HEAD_NODE_POOL_SELECTOR_VALUE, - } + if self.config.pathways_head_on_tpu: + # pylint: disable-next=protected-access + pod = self._inner._build_pod() + node_selector = {} + tolerations = pod["spec"]["tolerations"] + else: + node_selector = { + _PATHWAYS_HEAD_NODE_POOL_SELECTOR_KEY: _PATHWAYS_HEAD_NODE_POOL_SELECTOR_VALUE, + } head_container = self._build_pathways_head_container() init_containers = [ @@ -437,6 +451,26 @@ def _build_pathways_head_pod(self) -> Nested[Any]: "hostAliases": [metadata_host_alias], "nodeSelector": node_selector, "tolerations": tolerations, + "affinity": { + "podAffinity": { + "requiredDuringSchedulingIgnoredDuringExecution": [ + { + "labelSelector": { + "matchExpressions": [ + { + "key": "batch.kubernetes.io/job-name", + "operator": "In", + "values": [ + f"{cfg.name}-{_PATHWAYS_WORKER_REPLICATED_JOB_NAME}-0" + ], + } + ] + }, + "topologyKey": "kubernetes.io/hostname", + } + ] + } + }, "containers": [head_container], "initContainers": init_containers, "volumes": volumes, @@ -445,6 +479,11 @@ def _build_pathways_head_pod(self) -> Nested[Any]: "dnsPolicy": "ClusterFirstWithHostNet", } + # Remove host ports to avoid scheduling conflicts on the same node. + # The pod runs on host network anyway, so the ports are still accessible. + if "ports" in head_pod_spec["containers"][0]: + del head_pod_spec["containers"][0]["ports"] + if cfg.priority_class: head_pod_spec["priorityClassName"] = cfg.priority_class @@ -645,18 +684,23 @@ def _build_pathways_worker_job( def __call__(self) -> Sequence[Nested[Any]]: cfg: TPUReplicatedJob.Config = self._inner.config - replicated_jobs = [ - dict( - name=_PATHWAYS_HEAD_REPLICATED_JOB_NAME, - replicas=1, - template=self._build_pathways_head_job(), - ), - dict( - name=_PATHWAYS_WORKER_REPLICATED_JOB_NAME, - replicas=cfg.accelerator.num_replicas, - template=self._build_pathways_worker_job(), - ), - ] + worker_job = dict( + name=_PATHWAYS_WORKER_REPLICATED_JOB_NAME, + replicas=cfg.accelerator.num_replicas, + template=self._build_pathways_worker_job(), + ) + head_job = dict( + name=_PATHWAYS_HEAD_REPLICATED_JOB_NAME, + replicas=1, + template=self._build_pathways_head_job(), + ) + if self.config.pathways_head_on_tpu: + head_job["dependsOn"] = [ + dict(name=_PATHWAYS_WORKER_REPLICATED_JOB_NAME, status="Ready") + ] + replicated_jobs = [worker_job, head_job] + else: + replicated_jobs = [head_job, worker_job] return replicated_jobs From 8dc5205d902fe05ff70858a37ec7c8d4975a9503 Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Sat, 13 Sep 2025 22:58:13 -0700 Subject: [PATCH 37/50] stick to default concurrent restore of 32gb --- axlearn/common/array_serialization.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/axlearn/common/array_serialization.py b/axlearn/common/array_serialization.py index 0c611b8b9..3f57bc1fc 100644 --- a/axlearn/common/array_serialization.py +++ b/axlearn/common/array_serialization.py @@ -618,10 +618,10 @@ def deserialize( tensorstore_specs: Sequence[dict[str, Any]], global_shapes: Optional[Sequence[array.Shape]] = None, dtypes: Optional[Sequence[typing.DTypeLike]] = None, - concurrent_gb: int = 128, + concurrent_gb: int = 32, ): # force to 128 - concurrent_gb = 
max(128, concurrent_gb) + # concurrent_gb = max(128, concurrent_gb) logging.info("concurrent_gb=%s GB.", concurrent_gb) self.wait_until_finished() start_time = time.time() From 6bae63d29ffc19ba88600f4da28ee417cac9bc7c Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Sat, 13 Sep 2025 23:00:23 -0700 Subject: [PATCH 38/50] switch to standard blocking device put --- axlearn/common/array_serialization.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/axlearn/common/array_serialization.py b/axlearn/common/array_serialization.py index 3f57bc1fc..78bd9b08d 100644 --- a/axlearn/common/array_serialization.py +++ b/axlearn/common/array_serialization.py @@ -352,8 +352,8 @@ async def _run_serializer( def _blocking_device_put(out: Tensor, layout: Layout) -> Tensor: # Make it non blocking - # return jax.block_until_ready(jax.device_put(out, layout)) - return jax.device_put(out, layout) + # return jax.device_put(out, layout) + return jax.block_until_ready(jax.device_put(out, layout)) async def _async_deserialize( @@ -652,7 +652,8 @@ async def _run_deserializer(): fut = asyncio.run_coroutine_threadsafe(_run_deserializer(), self._loop) result = fut.result() - jax.block_until_ready(result) + # Only needed when we use non blocking device put + # jax.block_until_ready(result) logging.info("deserialize took %.4f seconds.", time.time() - start_time) jax.profiler.stop_trace() sys.exit(0) From e38d34ff9a7c69dd15f04b779a5b724a50b30b9e Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Sat, 13 Sep 2025 23:09:04 -0700 Subject: [PATCH 39/50] concurrent restore 64GB --- axlearn/common/array_serialization.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/axlearn/common/array_serialization.py b/axlearn/common/array_serialization.py index 78bd9b08d..e27a87995 100644 --- a/axlearn/common/array_serialization.py +++ b/axlearn/common/array_serialization.py @@ -620,8 +620,8 @@ def deserialize( dtypes: Optional[Sequence[typing.DTypeLike]] = None, concurrent_gb: int = 32, ): - # force to 128 - # concurrent_gb = max(128, concurrent_gb) + # force to 64 + concurrent_gb = max(64, concurrent_gb) logging.info("concurrent_gb=%s GB.", concurrent_gb) self.wait_until_finished() start_time = time.time() From 032dd7bf249745a18845cb73a91fb5c87666a3a1 Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Sat, 13 Sep 2025 23:11:19 -0700 Subject: [PATCH 40/50] disable force to 64 --- axlearn/common/array_serialization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/axlearn/common/array_serialization.py b/axlearn/common/array_serialization.py index e27a87995..3993de4a6 100644 --- a/axlearn/common/array_serialization.py +++ b/axlearn/common/array_serialization.py @@ -621,7 +621,7 @@ def deserialize( concurrent_gb: int = 32, ): # force to 64 - concurrent_gb = max(64, concurrent_gb) + # concurrent_gb = max(64, concurrent_gb) logging.info("concurrent_gb=%s GB.", concurrent_gb) self.wait_until_finished() start_time = time.time() From 81756e5abb584f0448feec154f75f2103fd978a1 Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Sun, 14 Sep 2025 11:47:12 -0700 Subject: [PATCH 41/50] re-enable premap buffer --- axlearn/cloud/gcp/pathways_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/axlearn/cloud/gcp/pathways_utils.py b/axlearn/cloud/gcp/pathways_utils.py index 446a2d367..123a7b104 100644 --- a/axlearn/cloud/gcp/pathways_utils.py +++ b/axlearn/cloud/gcp/pathways_utils.py @@ -582,7 +582,7 @@ def _build_pathways_worker_container( # Recycle vs on-demand seems 
to give a slight perf boost "--tpu_pinned_host_allocation_recycle=true", # pylint: disable=line-too-long - # "--temporary_flags_for_debugging=temporary_flag_for_debugging_tpu_premapped_buffer_size=68719476736", + "--temporary_flags_for_debugging=temporary_flag_for_debugging_tpu_premapped_buffer_size=68719476736", # "--temporary_flags_for_debugging=temporary_flag_for_debuggings_max_num_threads_for_xla_compilation=1000" # "--temporary_flags_for_debugging=temporary_flag_for_debugging_xla_max_inflight_async_computations=1000", # "--temporary_flags_for_debugging=temporary_flag_for_debugging_xla_tpu_allow_async_allocations=true", @@ -958,7 +958,7 @@ def _build_pathways_rm_container(self) -> dict: # Troubleshooting perf "--tpu_pinned_host_allocation_recycle=true", # pylint: disable=line-too-long - # "--temporary_flags_for_debugging=temporary_flag_for_debugging_tpu_premapped_buffer_size=68719476736", + "--temporary_flags_for_debugging=temporary_flag_for_debugging_tpu_premapped_buffer_size=68719476736", # "--temporary_flags_for_debugging=temporary_flag_for_debuggings_max_num_threads_for_xla_compilation=1000" # "--temporary_flags_for_debugging=temporary_flag_for_debugging_xla_max_inflight_async_computations=1000", # "--temporary_flags_for_debugging=temporary_flag_for_debugging_xla_tpu_allow_async_allocations=true", From a4f2ba63488870929a16e4da9ae8d26a90da9f9f Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Mon, 15 Sep 2025 11:40:28 -0700 Subject: [PATCH 42/50] set cpu nodepool selector to c4 --- axlearn/cloud/gcp/pathways_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/axlearn/cloud/gcp/pathways_utils.py b/axlearn/cloud/gcp/pathways_utils.py index 123a7b104..47e463789 100644 --- a/axlearn/cloud/gcp/pathways_utils.py +++ b/axlearn/cloud/gcp/pathways_utils.py @@ -67,8 +67,8 @@ _PATHWAYS_WORKER_REPLICATED_JOB_NAME = "pathways-worker" # Add node-selector for cpu workload to avoid sharing nodes with system services. -_PATHWAYS_HEAD_NODE_POOL_SELECTOR_KEY = "axlearn/nodepool_type" -_PATHWAYS_HEAD_NODE_POOL_SELECTOR_VALUE = "workload" +_PATHWAYS_HEAD_NODE_POOL_SELECTOR_KEY = "node.kubernetes.io/instance-type" +_PATHWAYS_HEAD_NODE_POOL_SELECTOR_VALUE = "c4-standard-192" # The back off limit of pathways pods. # Note that the head pod will back of exact this many times. # While workers will share #workers * _PATHWAYS_BACK_OFF_LIMIT total times. From 0471d3c6a6e1780479f8eb32251584e8115bacd1 Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Mon, 15 Sep 2025 13:15:14 -0700 Subject: [PATCH 43/50] Revert "set cpu nodepool selector to c4" This reverts commit a4f2ba63488870929a16e4da9ae8d26a90da9f9f. --- axlearn/cloud/gcp/pathways_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/axlearn/cloud/gcp/pathways_utils.py b/axlearn/cloud/gcp/pathways_utils.py index 47e463789..123a7b104 100644 --- a/axlearn/cloud/gcp/pathways_utils.py +++ b/axlearn/cloud/gcp/pathways_utils.py @@ -67,8 +67,8 @@ _PATHWAYS_WORKER_REPLICATED_JOB_NAME = "pathways-worker" # Add node-selector for cpu workload to avoid sharing nodes with system services. -_PATHWAYS_HEAD_NODE_POOL_SELECTOR_KEY = "node.kubernetes.io/instance-type" -_PATHWAYS_HEAD_NODE_POOL_SELECTOR_VALUE = "c4-standard-192" +_PATHWAYS_HEAD_NODE_POOL_SELECTOR_KEY = "axlearn/nodepool_type" +_PATHWAYS_HEAD_NODE_POOL_SELECTOR_VALUE = "workload" # The back off limit of pathways pods. # Note that the head pod will back of exact this many times. # While workers will share #workers * _PATHWAYS_BACK_OFF_LIMIT total times. 
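The concurrent_gb value being forced to 128, then 64, then left at the default 32 across the commits above caps how many gigabytes of shard reads may be in flight at once during restore. A minimal sketch of that kind of byte budget, using a hypothetical InFlightByteLimiter class rather than the limiter the library actually uses:

import asyncio


class InFlightByteLimiter:
    """Caps the total bytes of shard reads that may be in flight at once."""

    def __init__(self, max_bytes: int):
        self._available = max_bytes
        self._cv = asyncio.Condition()

    async def acquire(self, nbytes: int):
        async with self._cv:
            # Block until the requested bytes fit in the remaining budget.
            await self._cv.wait_for(lambda: self._available >= nbytes)
            self._available -= nbytes

    async def release(self, nbytes: int):
        async with self._cv:
            self._available += nbytes
            self._cv.notify_all()


async def bounded_read(limiter: InFlightByteLimiter, nbytes: int) -> bytes:
    await limiter.acquire(nbytes)
    try:
        return b"\x00" * nbytes  # Stand-in for a tensorstore shard read.
    finally:
        await limiter.release(nbytes)


async def main():
    # concurrent_gb=32 corresponds to a 32 * 10**9 byte budget.
    limiter = InFlightByteLimiter(32 * 10**9)
    await asyncio.gather(*(bounded_read(limiter, 256 * 1024 * 1024) for _ in range(8)))


asyncio.run(main())

A larger budget lets more reads and host-to-device copies overlap at the cost of pinned host memory; the commits above experiment with larger budgets before settling back on the 32 GB default.
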
From 09c5bff4f41f073b9f51eb063a7fbea9296ada6c Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Mon, 15 Sep 2025 14:00:11 -0700 Subject: [PATCH 44/50] fix training script --- train_fuji_7b.sh | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/train_fuji_7b.sh b/train_fuji_7b.sh index 68c1f2d65..b0af00997 100644 --- a/train_fuji_7b.sh +++ b/train_fuji_7b.sh @@ -1,7 +1,10 @@ + +# --trainer_dir=gs://cloud-tpu-multipod-dev-uss1/axlearn-fuji-v2-7b/ \ + python3 -m axlearn.common.launch_trainer_main \ --module=text.gpt.c4_trainer \ --config=fuji-7B-v2-flash \ - --trainer_dir=gs://cloud-tpu-multipod-dev-uss1/axlearn-fuji-v2-7b/ \ + --trainer_dir=gs://cloud-tpu-multipod-dev-euw4/axlearn-fuji-v2-7b/ \ --data_dir=gs://axlearn-public/tensorflow_datasets \ --jax_backend=proxy \ --mesh_selector=tpu-v5litepod-32-1 From 47e2e7425f813b88089e5cc9501514912ed9d4df Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Mon, 15 Sep 2025 14:00:26 -0700 Subject: [PATCH 45/50] fix pathways_head_on_tpu=false --- axlearn/cloud/gcp/pathways_utils.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/axlearn/cloud/gcp/pathways_utils.py b/axlearn/cloud/gcp/pathways_utils.py index 123a7b104..e3530f1c7 100644 --- a/axlearn/cloud/gcp/pathways_utils.py +++ b/axlearn/cloud/gcp/pathways_utils.py @@ -451,7 +451,15 @@ def _build_pathways_head_pod(self) -> Nested[Any]: "hostAliases": [metadata_host_alias], "nodeSelector": node_selector, "tolerations": tolerations, - "affinity": { + "containers": [head_container], + "initContainers": init_containers, + "volumes": volumes, + "serviceAccountName": cfg.service_account, + "hostNetwork": True, + "dnsPolicy": "ClusterFirstWithHostNet", + } + if self.config.pathways_head_on_tpu: + head_pod_spec["affinity"] = { "podAffinity": { "requiredDuringSchedulingIgnoredDuringExecution": [ { @@ -470,14 +478,7 @@ def _build_pathways_head_pod(self) -> Nested[Any]: } ] } - }, - "containers": [head_container], - "initContainers": init_containers, - "volumes": volumes, - "serviceAccountName": cfg.service_account, - "hostNetwork": True, - "dnsPolicy": "ClusterFirstWithHostNet", - } + } # Remove host ports to avoid scheduling conflicts on the same node. # The pod runs on host network anyway, so the ports are still accessible. 
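The affinity block that the change above makes conditional exists to co-schedule the pathways head with worker replica 0: when pathways_head_on_tpu is set, the head pod may only land on the node whose pods carry that worker job's batch.kubernetes.io/job-name label. A standalone sketch of the stanza, with an illustrative helper name that is not part of the library:

def head_pod_affinity(jobset_name: str, worker_job: str = "pathways-worker") -> dict:
    """Pin the head pod to whichever node runs worker replica 0 of the JobSet."""
    return {
        "podAffinity": {
            "requiredDuringSchedulingIgnoredDuringExecution": [
                {
                    "labelSelector": {
                        "matchExpressions": [
                            {
                                "key": "batch.kubernetes.io/job-name",
                                "operator": "In",
                                "values": [f"{jobset_name}-{worker_job}-0"],
                            }
                        ]
                    },
                    # Same hostname means the same node, i.e. the same TPU VM.
                    "topologyKey": "kubernetes.io/hostname",
                }
            ]
        }
    }


# Applied only on the TPU-VM path, e.g.:
# head_pod_spec["affinity"] = head_pod_affinity(cfg.name)
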
From d617723b4f117e17b654f60427fdf2b352e06698 Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Mon, 15 Sep 2025 14:37:09 -0700 Subject: [PATCH 46/50] uniqe id for each xprof --- proxy_bench.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/proxy_bench.py b/proxy_bench.py index a3c95b8f2..1f89a84c4 100644 --- a/proxy_bench.py +++ b/proxy_bench.py @@ -3,6 +3,7 @@ import asyncio import os import time +import uuid import jax import numpy as np @@ -36,8 +37,9 @@ async def benchmark_host_to_device_throughput( print(f"Starting benchmark ({num_transfers} transfers)...") for i in range(num_transfers): + uid = uuid.uuid4() if i == 0: - trace_dir = "gs://cloud-tpu-multipod-dev-axlearn/stoelinga-proxy-benchmark-premap" + trace_dir = f"gs://cloud-tpu-multipod-dev-axlearn/{uid}" jax.profiler.start_trace(f"{trace_dir}/{device_put_buffer_mb}mb") start_time = time.perf_counter() From 08718bd7c9d9653a56a54a1deb3410de3078fb56 Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Mon, 15 Sep 2025 15:38:34 -0700 Subject: [PATCH 47/50] use privileged to get rid of zero copy warning --- axlearn/cloud/gcp/pathways_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/axlearn/cloud/gcp/pathways_utils.py b/axlearn/cloud/gcp/pathways_utils.py index e3530f1c7..e3afbc3a5 100644 --- a/axlearn/cloud/gcp/pathways_utils.py +++ b/axlearn/cloud/gcp/pathways_utils.py @@ -362,6 +362,7 @@ def _build_pathways_head_sidecar_containers(self) -> list[Nested[Any]]: dict( name=_PATHWAYS_PROXY_CONTAINER_NAME, image=_PATHWAYS_PROXY_IMAGE, + securityContext={"privileged": True}, # https://kubernetes.io/docs/concepts/workloads/pods/sidecar-containers/#pod-sidecar-containers # SideCar container is an init container with restartPolicy as "Always". restartPolicy="Always", @@ -921,6 +922,7 @@ def _build_pathways_proxy_container(self) -> dict: return dict( name=_PATHWAYS_PROXY_CONTAINER_NAME, image=_PATHWAYS_PROXY_IMAGE, + securityContext={"privileged": True}, args=[ f"--resource_manager_address=localhost:{_PATHWAYS_RESOURCE_MANAGER_PORT}", f"--server_port={_PATHWAYS_PROXY_PORT}", From cf24047fd80e8aa61dc8b1514f58a90e2290a55e Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Tue, 16 Sep 2025 10:35:41 -0700 Subject: [PATCH 48/50] wip: unix socket pathways proxy --- axlearn/cloud/gcp/pathways_utils.py | 30 ++++++++++++++++++++++++----- 1 file changed, 25 insertions(+), 5 deletions(-) diff --git a/axlearn/cloud/gcp/pathways_utils.py b/axlearn/cloud/gcp/pathways_utils.py index e3afbc3a5..098edbd1d 100644 --- a/axlearn/cloud/gcp/pathways_utils.py +++ b/axlearn/cloud/gcp/pathways_utils.py @@ -48,14 +48,19 @@ # There is no guarantee that this image will work with newer Jax releases. # This image version extends GRPC timeout for long context models, based on jax-0.5.3-patch060625 # This image extends GRPC timeout for long context models. -_PATHWAYS_IMAGE_TAG = "disable_settings_20250701" +# _PATHWAYS_IMAGE_TAG = "disable_settings_20250701" +_PATHWAYS_IMAGE_TAG = "uds" # The docker image used by pathways proxy container. _PATHWAYS_PROXY_IMAGE = ( - f"us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:{_PATHWAYS_IMAGE_TAG}" + # f"us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:{_PATHWAYS_IMAGE_TAG}" + "us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/gke/shauryag/" + f"unsanitized_proxy_server:{_PATHWAYS_IMAGE_TAG}" ) # The docker image used by pathways resource manager container and worker container. 
_PATHWAYS_SERVER_IMAGE = ( - f"us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:{_PATHWAYS_IMAGE_TAG}" + # f"us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:{_PATHWAYS_IMAGE_TAG}" + "us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/gke/shauryag/" + f"unsanitized_server:{_PATHWAYS_IMAGE_TAG}" ) # The container name of pathways resourcemanager. _PATHWAYS_RESOURCE_MANAGER_CONTAINER_NAME = "pathways-rm" @@ -269,10 +274,16 @@ def _build_pathways_head_container(self) -> dict: head_container = copy.deepcopy(container) env_list = head_container.get("env", []) + # self._update_env_list( + # env_list, + # "JAX_BACKEND_TARGET", + # f"grpc://localhost:{_PATHWAYS_PROXY_PORT}", + # ) + # Unix domain socket self._update_env_list( env_list, "JAX_BACKEND_TARGET", - f"grpc://localhost:{_PATHWAYS_PROXY_PORT}", + "grpc:///tmp/ifrt_proxy.sock", ) self._update_env_list(env_list, "XCLOUD_ENVIRONMENT", "GCP") self._update_env_list(env_list, "JAX_PLATFORMS", "proxy") @@ -327,6 +338,10 @@ def _build_pathways_head_container(self) -> dict: } head_container["resources"] = resources + volume_mounts = head_container.get("volumeMounts", []) + volume_mounts.append(dict(name="shared-memory", mountPath="/tmp/")) + head_container["volumeMounts"] = volume_mounts + return head_container def _build_pathways_head_sidecar_containers(self) -> list[Nested[Any]]: @@ -350,6 +365,7 @@ def _build_pathways_head_sidecar_containers(self) -> list[Nested[Any]]: cmd_args = [ f"--resource_manager_address=localhost:{_PATHWAYS_RESOURCE_MANAGER_PORT}", + # using unix socket but port needs to be set anyway f"--server_port={_PATHWAYS_PROXY_PORT}", f"--gcs_scratch_location={staging_location}", ] @@ -374,7 +390,10 @@ def _build_pathways_head_sidecar_containers(self) -> list[Nested[Any]]: {"name": "XLA_FLAGS", "value": f"--xla_dump_to=/output/{cfg.name}/xla"}, ], ports=[dict(containerPort=_PATHWAYS_PROXY_PORT)], - volumeMounts=[dict(name="shared-output", mountPath="/output")], + volumeMounts=[ + dict(name="shared-output", mountPath="/output"), + dict(name="shared-memory", mountPath="/tmp/"), + ], ), dict( name=_PATHWAYS_RESOURCE_MANAGER_CONTAINER_NAME, @@ -412,6 +431,7 @@ def _build_pathways_head_pod(self) -> Nested[Any]: labels.update({BASTION_JOB_VERSION_LABEL: os.environ.get(BASTION_JOB_VERSION_ENV_VAR)}) volumes.append(dict(name="shared-output", emptyDir={})) + volumes.append(dict(name="shared-memory", emptyDir=dict(medium="Memory"))) if cfg.gcsfuse_mount: annotations.update( From 6792b25b7675a4f04c0e518e51ff7415dab29063 Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Tue, 16 Sep 2025 11:21:38 -0700 Subject: [PATCH 49/50] train fuji 7b on v5p / restore --- train_fuji_7b.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_fuji_7b.sh b/train_fuji_7b.sh index b0af00997..1d7822c41 100644 --- a/train_fuji_7b.sh +++ b/train_fuji_7b.sh @@ -7,4 +7,4 @@ python3 -m axlearn.common.launch_trainer_main \ --trainer_dir=gs://cloud-tpu-multipod-dev-euw4/axlearn-fuji-v2-7b/ \ --data_dir=gs://axlearn-public/tensorflow_datasets \ --jax_backend=proxy \ - --mesh_selector=tpu-v5litepod-32-1 + --mesh_selector=tpu-v5p From 22b1eb7e38bd5972a89bc4f3012ff5043b5f46e0 Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Tue, 16 Sep 2025 11:23:04 -0700 Subject: [PATCH 50/50] 7b remove nodeSelector --- axlearn/experiments/text/gpt/fuji.py | 2 +- train_fuji_7b.sh | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/axlearn/experiments/text/gpt/fuji.py b/axlearn/experiments/text/gpt/fuji.py index 3373f75f4..3b892f701 100644 
--- a/axlearn/experiments/text/gpt/fuji.py +++ b/axlearn/experiments/text/gpt/fuji.py @@ -412,7 +412,7 @@ def get_trainer_kwargs( train_batch_size=train_batch_size, save_every_n_steps=100, max_step=max_step, - mesh_shape=mesh_shape_from_axes(data=-1, fsdp=8), + mesh_shape=mesh_shape_from_axes(fsdp=-1), mesh_rules=( # Step time: # v1 on tpu-v4-1024 (512 chips): 3.03s diff --git a/train_fuji_7b.sh b/train_fuji_7b.sh index 1d7822c41..06706944f 100644 --- a/train_fuji_7b.sh +++ b/train_fuji_7b.sh @@ -6,5 +6,5 @@ python3 -m axlearn.common.launch_trainer_main \ --config=fuji-7B-v2-flash \ --trainer_dir=gs://cloud-tpu-multipod-dev-euw4/axlearn-fuji-v2-7b/ \ --data_dir=gs://axlearn-public/tensorflow_datasets \ - --jax_backend=proxy \ - --mesh_selector=tpu-v5p + --jax_backend=proxy +# --mesh_selector=tpu-v5p
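
The mesh change in the final commit replaces the fixed data=-1, fsdp=8 layout with fsdp=-1, i.e. all available chips go onto the fsdp axis. A rough plain-JAX equivalent of the resulting mesh, with illustrative array shapes:

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# fsdp=-1 resolves to "all devices on the fsdp axis".
devices = np.array(jax.devices())
mesh = Mesh(devices.reshape(-1), ("fsdp",))

# Parameters sharded along fsdp are then split across every chip;
# the array shape here is only an example.
x = np.zeros((8192, 1024), dtype=np.float32)
sharded = jax.device_put(x, NamedSharding(mesh, PartitionSpec("fsdp")))
print(sharded.sharding, sharded.shape)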