Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions axlearn/cli/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ def test_subprocess_argv(self):
absl_main(mock_args)
self.assertEqual(1, len(mock_popen.call_args_list))
self.assertEqual((expected,), mock_popen.call_args[0])
self.assertDictContainsSubset({"text": True}, mock_popen.call_args[1])
self.assertTrue({"text": True}.items() <= mock_popen.call_args[1].items())
self.assertEqual(self.root_module, mock_popen.call_args[1]["env"]["AXLEARN_CLI_NAME"])

shell_cases = [
Expand Down Expand Up @@ -397,8 +397,8 @@ def test_subprocess_argv(self):
absl_main(mock_args)
self.assertEqual(1, len(mock_popen.call_args_list))
self.assertEqual((expected,), mock_popen.call_args[0])
self.assertDictContainsSubset(
{"text": True, "shell": True}, mock_popen.call_args[1]
self.assertTrue(
{"text": True, "shell": True}.items() <= mock_popen.call_args[1].items()
)
self.assertEqual(
self.root_module, mock_popen.call_args[1]["env"]["AXLEARN_CLI_NAME"]
Expand Down
3 changes: 1 addition & 2 deletions axlearn/common/ops/_optimization_barrier.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from typing import Any

import jax
from jax._src import ad_checkpoint # pylint: disable=protected-access


@jax.custom_jvp
Expand Down Expand Up @@ -94,7 +93,7 @@ def print_result(msg: str, scalars: Tensor):
Returns:
`pytree` transparently wrapped in an XLA optimization barrier.
"""
return ad_checkpoint._optimization_barrier(pytree) # pylint: disable=protected-access
return jax.lax.optimization_barrier(pytree)


@forward_optimization_barrier.defjvp
Expand Down
3 changes: 1 addition & 2 deletions axlearn/common/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

# pylint: disable=no-self-use
import jax
import jaxlib
import numpy as np
import pytest
import tensorflow as tf
Expand Down Expand Up @@ -1017,7 +1016,7 @@ def f(x):
# With runtime_checks enabled, we should be able to crash with jittable checks without
# needing to checkify.
with runtime_checks():
with self.assertRaisesRegex(jaxlib.xla_extension.XlaRuntimeError, "cannot be zero!"):
with self.assertRaisesRegex(jax.errors.JaxRuntimeError, "cannot be zero!"):
jax.jit(f)(0)

def test_prng_impl(self):
Expand Down
Loading