diff --git a/axlearn/cli/utils_test.py b/axlearn/cli/utils_test.py index 62de78053..17545d0ae 100644 --- a/axlearn/cli/utils_test.py +++ b/axlearn/cli/utils_test.py @@ -368,7 +368,7 @@ def test_subprocess_argv(self): absl_main(mock_args) self.assertEqual(1, len(mock_popen.call_args_list)) self.assertEqual((expected,), mock_popen.call_args[0]) - self.assertDictContainsSubset({"text": True}, mock_popen.call_args[1]) + self.assertTrue({"text": True}.items() <= mock_popen.call_args[1].items()) self.assertEqual(self.root_module, mock_popen.call_args[1]["env"]["AXLEARN_CLI_NAME"]) shell_cases = [ @@ -397,8 +397,8 @@ def test_subprocess_argv(self): absl_main(mock_args) self.assertEqual(1, len(mock_popen.call_args_list)) self.assertEqual((expected,), mock_popen.call_args[0]) - self.assertDictContainsSubset( - {"text": True, "shell": True}, mock_popen.call_args[1] + self.assertTrue( + {"text": True, "shell": True}.items() <= mock_popen.call_args[1].items() ) self.assertEqual( self.root_module, mock_popen.call_args[1]["env"]["AXLEARN_CLI_NAME"] diff --git a/axlearn/common/ops/_optimization_barrier.py b/axlearn/common/ops/_optimization_barrier.py index ad9c0448e..99f400f5c 100644 --- a/axlearn/common/ops/_optimization_barrier.py +++ b/axlearn/common/ops/_optimization_barrier.py @@ -5,7 +5,6 @@ from typing import Any import jax -from jax._src import ad_checkpoint # pylint: disable=protected-access @jax.custom_jvp @@ -94,7 +93,7 @@ def print_result(msg: str, scalars: Tensor): Returns: `pytree` transparently wrapped in an XLA optimization barrier. """ - return ad_checkpoint._optimization_barrier(pytree) # pylint: disable=protected-access + return jax.lax.optimization_barrier(pytree) @forward_optimization_barrier.defjvp diff --git a/axlearn/common/utils_test.py b/axlearn/common/utils_test.py index ad5bdf027..91d14fc32 100644 --- a/axlearn/common/utils_test.py +++ b/axlearn/common/utils_test.py @@ -14,7 +14,6 @@ # pylint: disable=no-self-use import jax -import jaxlib import numpy as np import pytest import tensorflow as tf @@ -1017,7 +1016,7 @@ def f(x): # With runtime_checks enabled, we should be able to crash with jittable checks without # needing to checkify. with runtime_checks(): - with self.assertRaisesRegex(jaxlib.xla_extension.XlaRuntimeError, "cannot be zero!"): + with self.assertRaisesRegex(jax.errors.JaxRuntimeError, "cannot be zero!"): jax.jit(f)(0) def test_prng_impl(self):