From 625ce3b4a41b699aa001b42d4fcd25e38ab2e19f Mon Sep 17 00:00:00 2001 From: Chang Lan Date: Thu, 24 Jul 2025 09:12:20 -0700 Subject: [PATCH 1/7] Fix running max calculation when logit sink is present (#1512) Make sure sink's contribution is added once. Also added tests. GitOrigin-RevId: 8de870cff544d5c933a1778904a97ab0777254bc --- .../common/flash_attention/tpu_attention_test.py | 4 ++-- .../common/flash_attention/tpu_splash_attention.py | 13 ++++++------- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/axlearn/common/flash_attention/tpu_attention_test.py b/axlearn/common/flash_attention/tpu_attention_test.py index 335cd4f1a..9a38c0c5d 100644 --- a/axlearn/common/flash_attention/tpu_attention_test.py +++ b/axlearn/common/flash_attention/tpu_attention_test.py @@ -299,7 +299,7 @@ def test_logit_sink( # Compare outputs out = fn(input_batch) ref_out = ref_fn(input_batch) - self.assertNestedAllClose(out, ref_out, atol=2e-2) + self.assertNestedAllClose(out, ref_out, atol=1e-6 if q_dtype == jnp.float32 else 2e-2) # Compare gradients def grad_fn(float_inputs, aux_inputs, f): @@ -310,7 +310,7 @@ def grad_fn(float_inputs, aux_inputs, f): aux_inputs = dict(bias=bias, prng_key=prng_key) grad_out = jax.grad(grad_fn, argnums=0)(float_inputs, aux_inputs, fn) ref_grad_out = jax.grad(grad_fn, argnums=0)(float_inputs, aux_inputs, ref_fn) - self.assertNestedAllClose(grad_out, ref_grad_out, atol=1e-5) + self.assertNestedAllClose(grad_out, ref_grad_out, atol=1e-6) def test_logit_sink_shape_validation(self): """Test that logit sink shape validation works correctly.""" diff --git a/axlearn/common/flash_attention/tpu_splash_attention.py b/axlearn/common/flash_attention/tpu_splash_attention.py index 2c8f94f3d..be8b50567 100644 --- a/axlearn/common/flash_attention/tpu_splash_attention.py +++ b/axlearn/common/flash_attention/tpu_splash_attention.py @@ -25,9 +25,9 @@ participate in the max and sum computations but do not contribute to the output. When enabled, the `logit_sink` parameter provides per-head scalar values that are incorporated into the softmax normalization as follows: the running maximum is initialized with the sink value, and -at each step, the sink's contribution is added to the normalization sum (denominator) as -exp(logit_sink - running_max). The sink does not contribute to the numerator of the -attention-weighted sum, as it has no corresponding value. In the backward pass, gradients for +during the final normalization the sink's contribution is added once to the normalization sum +(denominator) as exp(logit_sink - running_max). The sink does not contribute to the numerator of +the attention-weighted sum, as it has no corresponding value. In the backward pass, gradients for the sink logits are computed as the negative sum of their attention weights multiplied by the output gradients, reflecting their role in the normalization term without direct output contribution. @@ -219,10 +219,6 @@ def body(kv_compute_index, _): assert s_curr.shape == (bq, bkv_compute) l_curr = jax.lax.broadcast_in_dim(s_curr.sum(axis=-1), l_prev.shape, (0,)) - # Add sink contribution to normalization sum. - if logit_sink_ref is not None: - sink_value = logit_sink_ref[h].astype(qk.dtype) - l_curr = l_curr + jnp.exp(sink_value - m_next[:, 0:1]) assert l_curr.shape == (bq, NUM_LANES) alpha = jnp.exp(m_prev - m_next) @@ -262,6 +258,9 @@ def run(): @pl.when(j == grid_width - 1) def end(): l = l_scratch_ref[...] + if logit_sink_ref is not None: + sink_value = logit_sink_ref[h].astype(jnp.float32) + l = l + jnp.exp(sink_value - m_scratch_ref[...]) l_inv = pltpu.repeat(1.0 / l, head_dim_repeats, axis=1) o_ref[...] = (o_scratch_ref[...] * l_inv).astype(o_ref.dtype) if logsumexp_ref is not None: From 7f359b2975b2eadcc23a91935fda361e684dcca5 Mon Sep 17 00:00:00 2001 From: Chang Lan Date: Thu, 24 Jul 2025 12:33:29 -0700 Subject: [PATCH 2/7] Relax pyarrow deps GitOrigin-RevId: 56cf7e8eb1b6d848be84206a6ce666d73c983413 --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e583bc0c4..03afb1635 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,7 @@ core = [ "nltk==3.7", # for text preprocessing "optax==0.1.7", # optimizers (0.1.0 has known bugs). "portpicker", - "pyarrow>=20.0.0,<21.0.0", # Pin to v20.x to avoid PyExtensionType -> ExtensionType breaking change in v21 + "pyarrow<21.0.0", # Pin to v20.x to avoid PyExtensionType -> ExtensionType breaking change in v21 "protobuf>=3.20.3", "tensorboard-plugin-profile==2.20.4", # This has both x86 and arm64 wheels. Underneath the hood it uses tensorflow-macos since 2.13. @@ -126,7 +126,7 @@ vertexai_tensorboard = [ ] # Dataflow dependencies. dataflow = [ - "pyarrow>=20.0.0,<21.0.0", # Pin to v20.x to avoid PyExtensionType -> ExtensionType breaking change in v21 + "pyarrow<21.0.0", # Pin to v20.x to avoid PyExtensionType -> ExtensionType breaking change in v21 "apache-beam==2.55.1", "apache-beam[gcp]", "google-apitools", # for beam pipeline From 2e03b1229115b209f984ecb7a8791f125cdebf80 Mon Sep 17 00:00:00 2001 From: Ruixuan Hou Date: Fri, 25 Jul 2025 07:08:23 -0700 Subject: [PATCH 3/7] pin nccl version * pin nccl version * empty commit * add actual pacakge * trigger new build to address flaky test * Update pyproject.toml GitOrigin-RevId: b8653ad5e43a73cb027ae07dc73405d66b042203 --- pyproject.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 03afb1635..8e110d2f7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -137,6 +137,8 @@ gpu = [ "triton==2.1.0", "jax[cuda12]==0.5.3", "nvidia-ml-py==12.560.30", + # pin nccl version, otherwise jax[cuda12] will pull latest version + "nvidia-nccl-cu12==2.27.5", ] # Open API inference. open_api = [ From 0560ff6fa706d23134d8d1f252ee81e6a4a6afef Mon Sep 17 00:00:00 2001 From: ReNothing Date: Fri, 25 Jul 2025 21:39:01 +0300 Subject: [PATCH 4/7] =?UTF-8?q?=D0=94=D0=BE=D0=B1=D0=B0=D0=B2=D0=B8=D1=82?= =?UTF-8?q?=D1=8C=20=D0=BF=D0=BE=D0=B4=D0=B4=D0=B5=D1=80=D0=B6=D0=BA=D1=83?= =?UTF-8?q?=20=D0=BA=D0=BB=D1=8E=D1=87=D0=B5=D0=B9=20=D0=B4=D0=BB=D1=8F=20?= =?UTF-8?q?pytree=20=D0=B2=20MetricAccumulator=20=D0=B8=20=D0=BE=D0=B1?= =?UTF-8?q?=D0=BD=D0=BE=D0=B2=D0=B8=D1=82=D1=8C=20=D1=84=D1=83=D0=BD=D0=BA?= =?UTF-8?q?=D1=86=D0=B8=D0=B8=20flatten/unflatten=20=D0=B2=20utils?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- axlearn/common/metrics.py | 6 ++++++ axlearn/common/struct.py | 7 ++++++- axlearn/common/utils.py | 19 +++++++++++++++++++ 3 files changed, 31 insertions(+), 1 deletion(-) diff --git a/axlearn/common/metrics.py b/axlearn/common/metrics.py index f1e05ba27..4820f87bb 100644 --- a/axlearn/common/metrics.py +++ b/axlearn/common/metrics.py @@ -111,3 +111,9 @@ def _metric_accumulator_unflatten( _metric_accumulator_flatten, _metric_accumulator_unflatten, ) + +jax.tree.register_pytree_node( + MetricAccumulator, + _metric_accumulator_flatten, + _metric_accumulator_unflatten, +) \ No newline at end of file diff --git a/axlearn/common/struct.py b/axlearn/common/struct.py index 798f223f0..cd2f64ec4 100644 --- a/axlearn/common/struct.py +++ b/axlearn/common/struct.py @@ -76,16 +76,21 @@ def flatten_func(x): return data, meta def flatten_with_keys(x) -> tuple[tuple, tuple]: - data = tuple((jax.tree_util.GetAttrKey(name), getattr(x, name)) for name in data_fields) + data = tuple((jax.tree.GetAttrKey(name), getattr(x, name)) for name in data_fields) meta = tuple(getattr(x, name) for name in meta_fields) return data, meta + # Note that meta, data are tuples as produced by `flatten_with_keys`. def unflatten_func(meta: tuple, data: tuple): # Support unflattening from chex.dataclass which requires handling lists. data = tuple(data) return dataklass(**dict(zip(meta_fields + data_fields, meta + data))) + jax.tree.register_pytree_with_keys( + dataklass, flatten_with_keys, unflatten_func, flatten_func + ) + jax.tree_util.register_pytree_with_keys( dataklass, flatten_with_keys, unflatten_func, flatten_func ) diff --git a/axlearn/common/utils.py b/axlearn/common/utils.py index f8a9048d1..58be51454 100644 --- a/axlearn/common/utils.py +++ b/axlearn/common/utils.py @@ -88,6 +88,25 @@ # The set of supported floating point dtypes. _supported_float_dtypes = [jnp.bfloat16, jnp.float32] +@staticmethod +def _tree_map(*args, **kwargs): + is_leaf = lambda x: isinstance(x, Summary) + return jax.tree.map(*args, **kwargs, is_leaf=is_leaf) + +def pytree_children(node: Any) -> Sequence[tuple[KeyEntry, Any]]: + """Generate the (key, value) pairs for the immediate children of a pytree `node`.""" + flat = jax.tree.default_registry.flatten_one_level(node) + if flat is None: + return [] + + if isinstance(node, tuple) and hasattr(node, "_fields") and flat[1] == type(node): + return [(jax.tree.GetAttrKey(s), getattr(node, s)) for s in node._fields] + + key_children, _ = jax.tree.default_registry.flatten_one_level_with_keys(node) + if key_children: + return key_children + + return [(jax.tree.FlattenedIndexKey(i), c) for i, c in enumerate(flat[0])] @dataclasses.dataclass class HybridMeshShape: From 78833d4d63ae24527dc440b6e7920047efabb3ce Mon Sep 17 00:00:00 2001 From: "Ethan (Meng) Li" Date: Fri, 25 Jul 2025 10:34:21 -0700 Subject: [PATCH 5/7] Disable xla_tpu_ici_sdc_test_run_on_program_start by default GitOrigin-RevId: 65af8017215c962c58b4f6d98aabc21250285e78 --- axlearn/common/compiler_options.py | 5 +++-- axlearn/common/compiler_options_test.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/axlearn/common/compiler_options.py b/axlearn/common/compiler_options.py index 8f7a2bca7..63ebf21c9 100644 --- a/axlearn/common/compiler_options.py +++ b/axlearn/common/compiler_options.py @@ -334,8 +334,9 @@ def infer_xsc_compiler_options( xla_tpu_sdc_checker_alternate_megacore_cores=True, # XLA ICI SDC Checker flags: # N.B. ICI checker only runs once after first program compilation. - # Enable the interconnect checker on first program call. - xla_tpu_ici_sdc_test_run_on_program_start=True, + # Disable the interconnect checker by default as it is not meant for production run. + # In a job with 32k chips, disabling it reduced compilation time from 18mins to 15s. + xla_tpu_ici_sdc_test_run_on_program_start=False, # Max distance between send/recv neighbours. xla_tpu_ici_sdc_test_max_distance=1, # Number of repeated send/recv before checking for equivalence. diff --git a/axlearn/common/compiler_options_test.py b/axlearn/common/compiler_options_test.py index dfc1addc5..3a08ae836 100644 --- a/axlearn/common/compiler_options_test.py +++ b/axlearn/common/compiler_options_test.py @@ -75,7 +75,7 @@ def test_xsc_compiler_options(self): xla_tpu_sdc_check_halt_on_detection=False, xla_tpu_sdc_replicate_llo=True, xla_tpu_sdc_checker_alternate_megacore_cores=True, - xla_tpu_ici_sdc_test_run_on_program_start=True, + xla_tpu_ici_sdc_test_run_on_program_start=False, xla_tpu_ici_sdc_test_max_distance=1, xla_tpu_ici_sdc_test_pipeline_depth=4, xla_tpu_ici_sdc_test_buffer_size_chunks=32, From a0992122d02ec9451a5b7be35878202b2dd0c8a4 Mon Sep 17 00:00:00 2001 From: ReNothing Date: Fri, 25 Jul 2025 21:52:09 +0300 Subject: [PATCH 6/7] =?UTF-8?q?=D0=9E=D0=B1=D0=BD=D0=BE=D0=B2=D0=B8=D1=82?= =?UTF-8?q?=D1=8C=20=D1=84=D1=83=D0=BD=D0=BA=D1=86=D0=B8=D1=8E=20is=5Fleaf?= =?UTF-8?q?=20=D0=B2=20=5Ftree=5Fmap=20=D0=B4=D0=BB=D1=8F=20=D0=BF=D0=BE?= =?UTF-8?q?=D0=B4=D0=B4=D0=B5=D1=80=D0=B6=D0=BA=D0=B8=20VDict=20=D0=B8=20T?= =?UTF-8?q?ensor;=20=D1=83=D0=B1=D1=80=D0=B0=D1=82=D1=8C=20=D0=BB=D0=B8?= =?UTF-8?q?=D1=88=D0=BD=D0=B8=D0=B9=20=D0=BF=D1=80=D0=BE=D0=B1=D0=B5=D0=BB?= =?UTF-8?q?=20=D0=B2=20=5Fenable=5Fnumeric=5Fchecks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- axlearn/common/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/axlearn/common/utils.py b/axlearn/common/utils.py index 58be51454..e1fc70d4f 100644 --- a/axlearn/common/utils.py +++ b/axlearn/common/utils.py @@ -82,7 +82,7 @@ # We avoid subscripting Sequence[int] so it can be used for isinstance checks. MeshShape = Sequence -_enable_numeric_checks = False +_enable_numeric_checks = False _enable_xla_runtime_errors = False # The set of supported floating point dtypes. @@ -90,7 +90,7 @@ @staticmethod def _tree_map(*args, **kwargs): - is_leaf = lambda x: isinstance(x, Summary) + is_leaf = lambda x: isinstance(x, VDict) or isinstance(x, Tensor) return jax.tree.map(*args, **kwargs, is_leaf=is_leaf) def pytree_children(node: Any) -> Sequence[tuple[KeyEntry, Any]]: From a276afcc307afd2e3b09571347f863e53b2ad9c5 Mon Sep 17 00:00:00 2001 From: ReNothing Date: Sat, 26 Jul 2025 10:08:52 +0300 Subject: [PATCH 7/7] =?UTF-8?q?=D0=A3=D0=B4=D0=B0=D0=BB=D0=B8=D1=82=D1=8C?= =?UTF-8?q?=20=D1=80=D0=B5=D0=B3=D0=B8=D1=81=D1=82=D1=80=D0=B0=D1=86=D0=B8?= =?UTF-8?q?=D1=8E=20pytree=20=D0=B4=D0=BB=D1=8F=20MetricAccumulator=20?= =?UTF-8?q?=D0=B8=20=D1=83=D0=B1=D1=80=D0=B0=D1=82=D1=8C=20=D0=BD=D0=B5?= =?UTF-8?q?=D0=B8=D1=81=D0=BF=D0=BE=D0=BB=D1=8C=D0=B7=D1=83=D0=B5=D0=BC?= =?UTF-8?q?=D1=8B=D0=B9=20=D0=BA=D0=BE=D0=B4=20=D0=B2=20dataclass?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- axlearn/common/metrics.py | 6 ------ axlearn/common/struct.py | 4 ---- axlearn/common/utils.py | 7 +------ 3 files changed, 1 insertion(+), 16 deletions(-) diff --git a/axlearn/common/metrics.py b/axlearn/common/metrics.py index 4820f87bb..7cf6cc38a 100644 --- a/axlearn/common/metrics.py +++ b/axlearn/common/metrics.py @@ -110,10 +110,4 @@ def _metric_accumulator_unflatten( MetricAccumulator, _metric_accumulator_flatten, _metric_accumulator_unflatten, -) - -jax.tree.register_pytree_node( - MetricAccumulator, - _metric_accumulator_flatten, - _metric_accumulator_unflatten, ) \ No newline at end of file diff --git a/axlearn/common/struct.py b/axlearn/common/struct.py index cd2f64ec4..08cc69017 100644 --- a/axlearn/common/struct.py +++ b/axlearn/common/struct.py @@ -91,10 +91,6 @@ def unflatten_func(meta: tuple, data: tuple): dataklass, flatten_with_keys, unflatten_func, flatten_func ) - jax.tree_util.register_pytree_with_keys( - dataklass, flatten_with_keys, unflatten_func, flatten_func - ) - def to_state_dict(x) -> utils.Nested[Any]: return {name: serialization.to_state_dict(getattr(x, name)) for name in data_fields} diff --git a/axlearn/common/utils.py b/axlearn/common/utils.py index e1fc70d4f..66f5bdc1b 100644 --- a/axlearn/common/utils.py +++ b/axlearn/common/utils.py @@ -82,17 +82,12 @@ # We avoid subscripting Sequence[int] so it can be used for isinstance checks. MeshShape = Sequence -_enable_numeric_checks = False +_enable_numeric_checks = False _enable_xla_runtime_errors = False # The set of supported floating point dtypes. _supported_float_dtypes = [jnp.bfloat16, jnp.float32] -@staticmethod -def _tree_map(*args, **kwargs): - is_leaf = lambda x: isinstance(x, VDict) or isinstance(x, Tensor) - return jax.tree.map(*args, **kwargs, is_leaf=is_leaf) - def pytree_children(node: Any) -> Sequence[tuple[KeyEntry, Any]]: """Generate the (key, value) pairs for the immediate children of a pytree `node`.""" flat = jax.tree.default_registry.flatten_one_level(node)