Skip to content
Merged
36 changes: 33 additions & 3 deletions qa/L1_jax_distributed_unittest/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,42 @@
#
# See LICENSE for license information.

set -xe
function test_fail() {
RET=1
FAILED_CASES="$FAILED_CASES $1"
echo "Error: sub-test failed: $1"
}

RET=0
FAILED_CASES=""

: ${TE_PATH:=/opt/transformerengine}
: ${XML_LOG_DIR:=/logs}
mkdir -p "$XML_LOG_DIR"

export NVTE_JAX_UNITTEST_LEVEL="L1"

# Use --xla_gpu_enable_triton_gemm=false to ensure the reference JAX implementation we are using is accurate.
XLA_FLAGS="$XLA_FLAGS --xla_gpu_enable_triton_gemm=false" NVTE_JAX_UNITTEST_LEVEL="L1" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/jax/test_distributed_*
SCRIPT_NAME=$TE_PATH/tests/jax/test_multi_process_distributed_grouped_gemm.py bash $TE_PATH/tests/jax/multi_process_launch.sh
export XLA_FLAGS="$XLA_FLAGS --xla_gpu_enable_triton_gemm=false"

python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_dist_dense.xml $TE_PATH/tests/jax/test_distributed_dense.py || test_fail "test_distributed_dense.py"

python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_helper.xml $TE_PATH/tests/jax/test_distributed_helper.py || test_fail "test_distributed_helper.py"

python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_dist_layernorm.xml $TE_PATH/tests/jax/test_distributed_layernorm.py || test_fail "test_distributed_layernorm.py"

python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_dist_mlp.xml $TE_PATH/tests/jax/test_distributed_layernorm_mlp.py || test_fail "test_distributed_layernorm_mlp.py"

python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_dist_softmax.xml $TE_PATH/tests/jax/test_distributed_softmax.py || test_fail "test_distributed_softmax.py"

python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_dist_fused_attn.xml $TE_PATH/tests/jax/test_distributed_fused_attn.py || test_fail "test_distributed_fused_attn.py"

# TODO(Phuong): add this test back after it is verified
# SCRIPT_NAME=$TE_PATH/tests/jax/test_multi_process_distributed_grouped_gemm.py bash $TE_PATH/tests/jax/multi_process_launch.sh || test_fail "test_multi_process_distributed_grouped_gemm.py"

if [ $RET -ne 0 ]; then
echo "Error: some sub-tests failed: $FAILED_CASES"
exit 1
fi
echo "All tests passed"
exit 0
10 changes: 9 additions & 1 deletion tests/jax/multi_process_launch.sh
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,14 @@ do
CUDA_VISIBLE_DEVICES=$i python $SCRIPT_NAME 127.0.0.1:12345 $i $NUM_RUNS > /dev/null 2>&1 &
done

CUDA_VISIBLE_DEVICES=0 python $SCRIPT_NAME 127.0.0.1:12345 0 $NUM_RUNS
CUDA_VISIBLE_DEVICES=0 python $SCRIPT_NAME 127.0.0.1:12345 0 $NUM_RUNS | tee stdout_multi_process.txt

wait

RET=0
if grep -q "FAILED" stdout_multi_process.txt; then
RET=1
fi
Comment on lines +26 to +28
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

logic: The grep pattern "FAILED" may not detect test failures. The test script at test_multi_process_distributed_grouped_gemm.py:148-150 calls jnp.allclose() without asserting the result, so tests pass even when values don't match. If the test were fixed to use assertions, failures would output tracebacks with "failed" (lowercase) not "FAILED" (uppercase).

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure on the exact error string reported here, but I do see jnp.allclose in test_multi_process_distributed_grouped_gemm.py here. From the docs, jnp.allclose returns a bool instead of asserting and the return value isn't used. Should we use our assert_allclose from utils instead? Can be a separate PR as this PR is focused on the test failure reporting issue

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah. I noticed it. This will be addressed in follow-up PR.


rm -f stdout_multi_process.txt
exit "$RET"
7 changes: 6 additions & 1 deletion tests/jax/test_custom_call_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,7 +605,12 @@ def test_norm_forward_with_tensor_scaling_fp8(
)

@pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason)
@pytest.mark.parametrize("out_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
@pytest.mark.parametrize(
"out_dtype",
[
jnp.float8_e4m3fn,
],
)
def test_norm_forward_with_block_scaling_fp8(
self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, out_dtype
):
Expand Down