-
Notifications
You must be signed in to change notification settings - Fork 27
Dev/fuyuajin/maxtext backend test #557
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
0511e5d
2e58d67
a506657
a6f37d8
b96d0b3
3767dbd
cbda947
e0794c9
a704cee
10d1a0f
f09fc6f
a0aed10
62c59e4
eac1364
08967fd
095b267
87bc2e7
fdb9c48
36a162d
ba5c95c
faec60e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,51 @@ | ||
| work_group: ${PRIMUS_TEAM:amd} | ||
| user_name: ${PRIMUS_USER:root} | ||
| exp_name: ${PRIMUS_EXP_NAME:llama3.1_405B-pretrain} | ||
| workspace: ./output | ||
|
|
||
| modules: | ||
| pre_trainer: | ||
| framework: maxtext | ||
| config: pre_trainer.yaml | ||
|
|
||
| # model to run | ||
| model: llama3.1_405B.yaml | ||
| overrides: | ||
| run_name: "llama3.1_405b_training" | ||
| base_output_directory: "./output" | ||
| steps: 50 | ||
| log_period: 10 | ||
| profiler: "" | ||
|
|
||
| # data | ||
| dataset_type: "synthetic" | ||
| hf_access_token: ${HF_TOKEN:""} | ||
|
|
||
| # checkpoint | ||
| enable_checkpointing: false | ||
| async_checkpointing: false | ||
|
|
||
| # inter-node parallelism strategy | ||
| dcn_data_parallelism: 1 | ||
| dcn_fsdp_parallelism: -1 | ||
| dcn_pipeline_parallelism: 1 | ||
| dcn_tensor_parallelism: 1 | ||
| dcn_sequence_parallelism: 1 | ||
|
|
||
| # intra-node parallelism strategy | ||
| ici_fsdp_parallelism: -1 | ||
| ici_data_parallelism: 1 | ||
| ici_sequence_parallelism: 1 | ||
| ici_tensor_parallelism: 1 | ||
| ici_pipeline_parallelism: 1 | ||
|
|
||
| remat_policy: 'full' | ||
| optimizer_memory_host_offload: False | ||
| param_scan_axis: 1 | ||
| megablox: False | ||
|
|
||
| use_iota_embed: True | ||
| scan_layers: True | ||
|
|
||
| max_target_length: 8192 | ||
| per_device_batch_size: 5 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -175,32 +175,49 @@ export NCCL_CHECKS_DISABLE=1 | |
|
|
||
| # Set InfiniBand GID index for NCCL communication | ||
| if [ "$USING_AINIC" == "1" ]; then | ||
| export ANP_HOME_DIR=${ANP_HOME_DIR:-"/opt/amd-anp"} | ||
| export RCCL_HOME_DIR=${RCCL_HOME_DIR:-"/opt/rccl"} | ||
| export MPI_HOME_DIR=${MPI_HOME_DIR:-"/opt/ompi"} | ||
| export NCCL_NET_PLUGIN=librccl-anp.so | ||
|
|
||
| LOG_INFO_RANK0 "Using AINIC" | ||
| LOG_INFO_RANK0 "RCCL_HOME_DIR: $RCCL_HOME_DIR" | ||
| LOG_INFO_RANK0 "ANP_HOME_DIR: $ANP_HOME_DIR" | ||
| LOG_INFO_RANK0 "MPI_HOME_DIR: $MPI_HOME_DIR" | ||
|
|
||
| # unset NCCL_IB_GID_INDEX | ||
| export NCCL_IB_GID_INDEX=1 | ||
| # export NCCL_IB_ROCE_VERSION_NUM=2 | ||
| export NCCL_MAX_P2P_CHANNELS=56 | ||
| export NCCL_IB_TC=104 | ||
| export NCCL_IB_FIFO_TC=192 | ||
| export NET_OPTIONAL_RECV_COMPLETION=1 | ||
| export NCCL_IB_USE_INLINE=1 | ||
| export RCCL_GDR_FLUSH_GPU_MEM_NO_RELAXED_ORDERING=0 | ||
| export NCCL_GDR_FLUSH_DISABLE=1 | ||
| export NCCL_DMABUF_ENABLE=0 | ||
| export NCCL_IGNORE_CPU_AFFINITY=1 | ||
| export NCCL_IB_QPS_PER_CONNECTION=1 | ||
|
|
||
| export LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu:/usr/lib/x86_64-linux-gnu/libibverbs:${RCCL_HOME_DIR}/build/release:${ANP_HOME_DIR}/build:${MPI_HOME_DIR}/lib:$LD_LIBRARY_PATH | ||
|
|
||
| if [ "${BACKEND:-}" == "MaxText" ]; then | ||
| # ------- RCCL/NCCL IB Tuning ------- | ||
| export IONIC_LOCKFREE=all | ||
| export NCCL_GDR_COPY_ENABLE=1 | ||
| export NCCL_GDR_FLUSH_DISABLE=1 | ||
| export NCCL_IB_ECE_ENABLE=0 | ||
| export NCCL_IB_FIFO_TC=184 | ||
| export NCCL_IB_GID_INDEX=1 | ||
| export NCCL_IB_PCI_RELAXED_ORDERING=1 | ||
| export NCCL_IB_TC=96 | ||
| export NCCL_IB_USE_INLINE=1 | ||
| export NCCL_IGNORE_CPU_AFFINITY=1 | ||
| export NCCL_PXN_DISABLE=0 | ||
| export NET_OPTIONAL_RECV_COMPLETION=1 | ||
| export RCCL_GDR_FLUSH_GPU_MEM_NO_RELAXED_ORDERING=0 | ||
| export RCCL_LL128_FORCE_ENABLE=1 | ||
| else | ||
| export ANP_HOME_DIR=${ANP_HOME_DIR:-"/opt/amd-anp"} | ||
| export RCCL_HOME_DIR=${RCCL_HOME_DIR:-"/opt/rccl"} | ||
| export MPI_HOME_DIR=${MPI_HOME_DIR:-"/opt/ompi"} | ||
| export NCCL_NET_PLUGIN=librccl-anp.so | ||
|
|
||
| LOG_INFO_RANK0 "RCCL_HOME_DIR: $RCCL_HOME_DIR" | ||
| LOG_INFO_RANK0 "ANP_HOME_DIR: $ANP_HOME_DIR" | ||
| LOG_INFO_RANK0 "MPI_HOME_DIR: $MPI_HOME_DIR" | ||
|
|
||
| # unset NCCL_IB_GID_INDEX | ||
| export NCCL_IB_GID_INDEX=1 | ||
| # export NCCL_IB_ROCE_VERSION_NUM=2 | ||
| export NCCL_MAX_P2P_CHANNELS=56 | ||
| export NCCL_IB_TC=104 | ||
| export NCCL_IB_FIFO_TC=192 | ||
| export NET_OPTIONAL_RECV_COMPLETION=1 | ||
| export NCCL_IB_USE_INLINE=1 | ||
| export RCCL_GDR_FLUSH_GPU_MEM_NO_RELAXED_ORDERING=0 | ||
| export NCCL_GDR_FLUSH_DISABLE=1 | ||
| export NCCL_DMABUF_ENABLE=0 | ||
| export NCCL_IGNORE_CPU_AFFINITY=1 | ||
| export NCCL_IB_QPS_PER_CONNECTION=1 | ||
|
|
||
| export LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu:/usr/lib/x86_64-linux-gnu/libibverbs:${RCCL_HOME_DIR}/build/release:${ANP_HOME_DIR}/build:${MPI_HOME_DIR}/lib:$LD_LIBRARY_PATH | ||
| fi | ||
| else | ||
| export NCCL_IB_GID_INDEX=3 | ||
| fi | ||
|
|
@@ -272,9 +289,15 @@ if [ "${BACKEND:-}" == "MaxText" ]; then | |
| export DUMP_HLO_DIR=${DUMP_HLO_DIR:-"${PRIMUS_PATH}/output/xla_dump_hlo"} | ||
| export DUMP_HLO=${DUMP_HLO:-0} | ||
| export NVTE_ALLOW_NONDETERMINISTIC_ALGO=1 | ||
| export XLA_PYTHON_CLIENT_MEM_FRACTION=.97 | ||
| if [ "${NNODES}" -gt 1 ]; then | ||
| export XLA_PYTHON_CLIENT_MEM_FRACTION=.93 | ||
| export JAX_HIP_GRAPH_LOWERING=false | ||
| else | ||
| export XLA_PYTHON_CLIENT_MEM_FRACTION=.97 | ||
| fi | ||
| export TF_CPP_MIN_LOG_LEVEL=2 # this env var is used to suppress the error logs at the end of training | ||
| export XLA_FLAGS="--xla_gpu_memory_limit_slop_factor=95 --xla_gpu_reduce_scatter_combine_threshold_bytes=8589934592 --xla_gpu_enable_command_buffer='' --xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_all_gather_combine_threshold_bytes=8589934592 --xla_gpu_enable_triton_gemm=false --xla_gpu_enable_cublaslt=true --xla_gpu_autotune_level=4 --xla_gpu_enable_all_gather_combine_by_dim=false" | ||
| export NVTE_USE_HIPBLASLT=1 | ||
| export XLA_FLAGS="--xla_gpu_memory_limit_slop_factor=95 --xla_gpu_reduce_scatter_combine_threshold_bytes=8589934592 --xla_gpu_graph_level=0 --xla_gpu_enable_latency_hiding_scheduler=True --xla_gpu_all_gather_combine_threshold_bytes=8589934592 --xla_gpu_enable_triton_gemm=False --xla_gpu_enable_cublaslt=True --xla_gpu_autotune_level=0 --xla_gpu_enable_all_gather_combine_by_dim=FALSE" | ||
| if [ "${DUMP_HLO}" = "1" ]; then | ||
| mkdir -p "${DUMP_HLO_DIR}" | ||
| export XLA_FLAGS="$XLA_FLAGS --xla_dump_to=$DUMP_HLO_DIR" | ||
|
|
@@ -403,6 +426,21 @@ if [[ "$PATCH_TE_FLASH_ATTN" == "1" ]]; then | |
| fi | ||
| LOG_INFO_RANK0 "" | ||
|
|
||
| # -------------------- Install required packages for Jax -------------------- | ||
| install_pkgs_for_maxtext() { | ||
| LOG_INFO_RANK0 "========== Install IB required packages for Jax/MaxText ==========" | ||
| apt update | ||
| apt install autoconf automake libtool pkg-config -y | ||
| apt install jq dpkg-dev kmod xz-utils -y | ||
| apt install libibverbs-dev ibverbs-utils infiniband-diags -y | ||
| apt install rdma-core librdmacm-dev libibverbs-dev libibumad-dev -y | ||
| LOG_INFO_RANK0 "========== Install IB required packages for Jax/MaxText Done ==========" | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. These are not for JAX/MaxText libraries per se, but rather to add missing dependencies not found in the public docker (like rocm/jax-training:maxtext-v26.1), right? @amd-fuyuajin We don't need to do this for megatron or torchtitan jobs? Or is this already installed in those dockers? @wenxie-amd
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. These packages are mainly related to InfiniBand/RDMA libraries. I see they are only installed when NNODES > 1 (line 440). They probably provide the networking stack for distributed training. Again, @llying-001 added this and can explain better.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yes, these packages are dependencies required for REBUILD_BNXT that are missing in the public JAX docker image (e.g., rocm/jax-training:maxtext-v26.1), but they are already installed in the Torch docker image (e.g., rocm/primus:v26.1) |
||
| } | ||
|
|
||
| if [[ "$NNODES" -gt 1 ]] && [[ "${BACKEND:-}" == "MaxText" ]]; then | ||
| install_pkgs_for_maxtext | ||
| fi | ||
|
|
||
| # ----------------- Rebuild nbxt ----------------- | ||
| export REBUILD_BNXT=${REBUILD_BNXT:-0} | ||
| export PATH_TO_BNXT_TAR_PACKAGE=${PATH_TO_BNXT_TAR_PACKAGE} | ||
|
|
@@ -423,20 +461,6 @@ else | |
| LOG_INFO "Skip bnxt rebuild. REBUILD_BNXT=$REBUILD_BNXT, PATH_TO_BNXT_TAR_PACKAGE=$PATH_TO_BNXT_TAR_PACKAGE" | ||
| fi | ||
|
|
||
| # -------------------- Install required packages for Jax -------------------- | ||
| install_pkgs_for_maxtext() { | ||
| LOG_INFO_RANK0 "========== Install required packages for Jax/MaxText ==========" | ||
| apt install iproute2 -y | ||
| apt install -y linux-headers-"$(uname -r)" libelf-dev | ||
| apt install -y gcc make libtool autoconf librdmacm-dev rdmacm-utils infiniband-diags ibverbs-utils perftest ethtool libibverbs-dev \ | ||
| rdma-core strace libibmad5 libibnetdisc5 ibverbs-providers libibumad-dev libibumad3 libibverbs1 libnl-3-dev libnl-route-3-dev | ||
| LOG_INFO_RANK0 "========== Install required packages for Jax/MaxText Done ==========" | ||
| } | ||
|
|
||
| if [[ "$NNODES" -gt 1 ]] && [[ "${BACKEND:-}" == "MaxText" ]]; then | ||
| install_pkgs_for_maxtext | ||
| fi | ||
|
|
||
| # -------------------- HipBLASLt Tuning -------------------- | ||
| handle_hipblaslt_tuning() { | ||
| local STAGE=${PRIMUS_HIPBLASLT_TUNING_STAGE:-0} | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you explain why we need to have different flags (NCCL_IB_TC, NCCL_IB_FIFO_TC) when using MaxText backend or not using MaxText backend? I think these flags are more related to cluster settings, right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@llying-001 can explain this better. I did not change any of this part.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I extracted these env flags for the MaxText backend from https://github.com/ROCm/MAD/blob/develop/scripts/jax-maxtext/jax_maxtext_multinode_benchmark.sh#L305. They are actually related to the cluster rather than the backend. Are the env flags in jax_maxtext_multinode_benchmark.sh configured for the Vultr cluster? @yeandy
For the Megatron/Titan backend, which cluster are the env flags in run_pretrain.sh configured for? @zhenhuang12
It would be great if we could unify them.