From 1a9f4d6ea268c93446ff15da8d959c100bfb5ee7 Mon Sep 17 00:00:00 2001
From: Samuel Andersen
Date: Tue, 19 Aug 2025 09:49:54 -0700
Subject: [PATCH 1/3] unit test changes

---
 axlearn/cli/utils_test.py                   | 6 +++---
 axlearn/common/ops/_optimization_barrier.py | 3 +--
 axlearn/common/utils_test.py                | 2 +-
 3 files changed, 5 insertions(+), 6 deletions(-)

diff --git a/axlearn/cli/utils_test.py b/axlearn/cli/utils_test.py
index 62de78053..17545d0ae 100644
--- a/axlearn/cli/utils_test.py
+++ b/axlearn/cli/utils_test.py
@@ -368,7 +368,7 @@ def test_subprocess_argv(self):
             absl_main(mock_args)
             self.assertEqual(1, len(mock_popen.call_args_list))
             self.assertEqual((expected,), mock_popen.call_args[0])
-            self.assertDictContainsSubset({"text": True}, mock_popen.call_args[1])
+            self.assertTrue({"text": True}.items() <= mock_popen.call_args[1].items())
             self.assertEqual(self.root_module, mock_popen.call_args[1]["env"]["AXLEARN_CLI_NAME"])
 
         shell_cases = [
@@ -397,8 +397,8 @@ def test_subprocess_argv(self):
                 absl_main(mock_args)
                 self.assertEqual(1, len(mock_popen.call_args_list))
                 self.assertEqual((expected,), mock_popen.call_args[0])
-                self.assertDictContainsSubset(
-                    {"text": True, "shell": True}, mock_popen.call_args[1]
+                self.assertTrue(
+                    {"text": True, "shell": True}.items() <= mock_popen.call_args[1].items()
                 )
                 self.assertEqual(
                     self.root_module, mock_popen.call_args[1]["env"]["AXLEARN_CLI_NAME"]
diff --git a/axlearn/common/ops/_optimization_barrier.py b/axlearn/common/ops/_optimization_barrier.py
index ad9c0448e..73a0ff887 100644
--- a/axlearn/common/ops/_optimization_barrier.py
+++ b/axlearn/common/ops/_optimization_barrier.py
@@ -94,8 +94,7 @@ def print_result(msg: str, scalars: Tensor):
     Returns:
         `pytree` transparently wrapped in an XLA optimization barrier.
     """
-    return ad_checkpoint._optimization_barrier(pytree)  # pylint: disable=protected-access
-
+    return jax.lax.optimization_barrier(pytree)
 
 @forward_optimization_barrier.defjvp
 def forward_optimization_barrier_jvp(primals: tuple, tangents: tuple) -> tuple[Any, Any]:
diff --git a/axlearn/common/utils_test.py b/axlearn/common/utils_test.py
index ad5bdf027..53dd6a33e 100644
--- a/axlearn/common/utils_test.py
+++ b/axlearn/common/utils_test.py
@@ -1017,7 +1017,7 @@ def f(x):
         # With runtime_checks enabled, we should be able to crash with jittable checks without
         # needing to checkify.
         with runtime_checks():
-            with self.assertRaisesRegex(jaxlib.xla_extension.XlaRuntimeError, "cannot be zero!"):
+            with self.assertRaisesRegex(jax.errors.JaxRuntimeError, "cannot be zero!"):
                 jax.jit(f)(0)
 
     def test_prng_impl(self):

From 62fc630b20ca0de47da8e9b35e2b6e4e08a0c677 Mon Sep 17 00:00:00 2001
From: Eric Shen
Date: Thu, 28 Aug 2025 22:21:44 +0000
Subject: [PATCH 2/3] remove unused import

---
 axlearn/common/ops/_optimization_barrier.py | 2 --
 axlearn/common/utils_test.py                | 1 -
 2 files changed, 3 deletions(-)

diff --git a/axlearn/common/ops/_optimization_barrier.py b/axlearn/common/ops/_optimization_barrier.py
index 73a0ff887..0dd57e03c 100644
--- a/axlearn/common/ops/_optimization_barrier.py
+++ b/axlearn/common/ops/_optimization_barrier.py
@@ -5,8 +5,6 @@
 from typing import Any
 
 import jax
-from jax._src import ad_checkpoint  # pylint: disable=protected-access
-
 
 @jax.custom_jvp
 @jax.custom_batching.custom_vmap  # Must be wrapped in this before custom_jvp.
diff --git a/axlearn/common/utils_test.py b/axlearn/common/utils_test.py
index 53dd6a33e..91d14fc32 100644
--- a/axlearn/common/utils_test.py
+++ b/axlearn/common/utils_test.py
@@ -14,7 +14,6 @@
 # pylint: disable=no-self-use
 
 import jax
-import jaxlib
 import numpy as np
 import pytest
 import tensorflow as tf

From 5f51b8d315c03f1906f1516245b726f6489a9c9b Mon Sep 17 00:00:00 2001
From: Eric Shen
Date: Thu, 28 Aug 2025 23:11:39 +0000
Subject: [PATCH 3/3] minor change to fix pre-commit black issue

---
 axlearn/common/ops/_optimization_barrier.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/axlearn/common/ops/_optimization_barrier.py b/axlearn/common/ops/_optimization_barrier.py
index 0dd57e03c..99f400f5c 100644
--- a/axlearn/common/ops/_optimization_barrier.py
+++ b/axlearn/common/ops/_optimization_barrier.py
@@ -6,6 +6,7 @@
 
 import jax
 
+
 @jax.custom_jvp
 @jax.custom_batching.custom_vmap  # Must be wrapped in this before custom_jvp.
 # PyTrees are defined by whether they are registered, not based on their type.
@@ -94,6 +95,7 @@ def print_result(msg: str, scalars: Tensor):
     """
     return jax.lax.optimization_barrier(pytree)
 
+
 @forward_optimization_barrier.defjvp
 def forward_optimization_barrier_jvp(primals: tuple, tangents: tuple) -> tuple[Any, Any]:
     """The JVP for `optimization_barrier`.