diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 2057e6b67..2d3bc3bad 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -13,7 +13,7 @@ env: PRIMUS_TURBO_COMMIT: 5233748e9c5c5795a6484ab31ece47c442d29ec2 # feat(mxfp4): refactor gemm mxfp4 and mxfp8. fuse transpose, hadamard transform and quantization. (#195) ROCSHMEM_COMMIT: 17ff985c026f9f97f85068647e863ab541dd5645 # Update version to 3.2.0 for 7.2.0 rocm release (#351) (#355) BASE_IMAGE: docker.io/rocm/primus:v26.1 - MAXTEXT_BASE_IMAGE: docker.io/rocm/jax-training:maxtext-v25.9 + MAXTEXT_BASE_IMAGE: docker.io/rocm/jax-training:maxtext-v26.1 jobs: code-lint: @@ -281,7 +281,7 @@ jobs: env: PRIMUS_WORKDIR: /wekafs/primus-data/primus_safe_ci/jax needs: [code-lint] - runs-on: [primus-lm-cicd-jax-8t8mh] + runs-on: [primus-lm-cicd-jax-v26d1-dl6qc] steps: - run: echo "🎉 Begin Primus-Turbo Checkout." - name: Set commit hash to env @@ -304,19 +304,16 @@ jobs: echo "✅ [Pip install requirements] started at: $(date)" mkdir -p ${PRIMUS_WORKDIR}/primus-cache python3 -m pip install --upgrade pip setuptools - pip3 install --cache-dir=${PRIMUS_WORKDIR}/primus-cache --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/rocm7.0 - MAX_JOBS=128 pip3 install --cache-dir=${PRIMUS_WORKDIR}/primus-cache --no-build-isolation --no-clean -r requirements.txt end_time=$(date +%s) elapsed=$((end_time - start_time)) echo "✅ [Pip install requirements] ended at: $(date)" echo "⏱️ [Pip install requirements] Total elapsed time: ${elapsed} seconds" start_time=$(date +%s) echo "✅ [build primus-turbo] started at: $(date)" - PRIMUS_TURBO_FRAMEWORK="JAX" pip3 install --no-build-isolation -e . -v end_time=$(date +%s) elapsed=$((end_time - start_time)) echo "✅ [build primus-turbo] ended at: $(date)" - echo "⏱️ [build primus-turbo] Total elapsed time: ${elapsed} seconds" + echo "⏱️ [build primus-turbo] Torch installation causes segfault, so we skip it and actually not install turbo. 
Total elapsed time: ${elapsed} seconds" - run: echo "🎉 Begin Primus Unit Test." - uses: actions/checkout@v4 with: diff --git a/docs/cli/PRIMUS-CLI-GUIDE.md b/docs/cli/PRIMUS-CLI-GUIDE.md index 3ec7a2e5b..ff954816b 100644 --- a/docs/cli/PRIMUS-CLI-GUIDE.md +++ b/docs/cli/PRIMUS-CLI-GUIDE.md @@ -36,7 +36,7 @@ ```bash # Run GEMM benchmark directly on current host -./primus-cli direct -- benchmark gemm -M 4096 -N 4096 -K 4096 +./primus-cli direct -- benchmark gemm --M 4096 --N 4096 --K 4096 ``` --- @@ -65,7 +65,7 @@ Primus CLI supports three execution modes, each suitable for different scenarios ./primus-cli direct -- train pretrain --config config.yaml # GEMM benchmark -./primus-cli direct -- benchmark gemm -M 4096 -N 4096 -K 4096 +./primus-cli direct -- benchmark gemm --M 4096 --N 4096 --K 4096 # Environment check (info only) ./primus-cli direct -- preflight --host --gpu --network @@ -115,7 +115,7 @@ Primus CLI supports three execution modes, each suitable for different scenarios # Set resource limits ./primus-cli container --cpus 32 --memory 256G \ - -- benchmark gemm -M 8192 -N 8192 -K 8192 + -- benchmark gemm --M 8192 --N 8192 --K 8192 # Mount local Primus code for development ./primus-cli container --volume ~/workspace/Primus:/workspace/Primus \ @@ -164,10 +164,18 @@ Primus CLI supports three execution modes, each suitable for different scenarios -- train pretrain --config deepseek_v2.yaml # Run distributed GEMM benchmark -./primus-cli slurm srun -N 2 -- benchmark gemm -M 16384 -N 16384 -K 16384 +./primus-cli slurm srun -N 2 -- benchmark gemm --M 16384 --N 16384 --K 16384 # Multi-node environment check (info only) +# this will generate a fast info report of the host, GPU, and network ./primus-cli slurm srun -N 4 -- preflight --host --gpu --network + +# this will generate a full preflight report of the host, GPU, and network, as well as the performance tests +./primus-cli slurm srun -N 4 -- preflight --report-file-name preflight-report-4N + +# if you are using 
AINIC in your cluster, use the appropriate configuration file +# for preflight test, set docker image to rocm/primus:v26.1 in the configuration file +./primus-cli --config runner/use_ainic.yaml slurm srun -N 2 -- preflight --report-file-name preflight-report-2N ``` **Suitable for**: @@ -250,6 +258,15 @@ direct: ./primus-cli --config prod.yaml slurm srun -N 8 -- train pretrain ``` +### Using AINIC Configuration File + +If you are using AINIC in your cluster, you can use the `runner/use_ainic.yaml` configuration file to configure the AINIC environment. This file includes pre-configured environment variables for AINIC: `USING_AINIC=1`, `NCCL_PXN_DISABLE=0`, and `NCCL_IB_GID_INDEX=1`. You can modify the `NCCL_IB_GID_INDEX` value based on your AINIC settings and update the `image` value to match your Docker image. + +Here is an example of using the AINIC configuration file to run a training job: +```bash +./primus-cli --config runner/use_ainic.yaml slurm srun -N 2 -- train pretrain --config examples/maxtext/configs/MI355X/llama2_7B-pretrain.yaml +``` + ### Configuration Priority **Priority Order** (high to low): @@ -306,13 +323,13 @@ Command-line args > Specified config file > System default config > User config #### GEMM Benchmark ```bash # Single-node GEMM -./primus-cli direct -- benchmark gemm -M 4096 -N 4096 -K 4096 +./primus-cli direct -- benchmark gemm --M 4096 --N 4096 --K 4096 # Run in container -./primus-cli container -- benchmark gemm -M 8192 -N 8192 -K 8192 +./primus-cli container -- benchmark gemm --M 8192 --N 8192 --K 8192 # Multi-node GEMM -./primus-cli slurm srun -N 2 -- benchmark gemm -M 16384 -N 16384 -K 16384 +./primus-cli slurm srun -N 2 -- benchmark gemm --M 16384 --N 16384 --K 16384 ``` #### Other Benchmarks diff --git a/examples/maxtext/configs/MI300X/mixtral_8x7B-pretrain.yaml b/examples/maxtext/configs/MI300X/mixtral_8x7B-pretrain.yaml index c543f9eb1..2bee81be5 100644 --- a/examples/maxtext/configs/MI300X/mixtral_8x7B-pretrain.yaml +++ 
b/examples/maxtext/configs/MI300X/mixtral_8x7B-pretrain.yaml @@ -39,4 +39,4 @@ modules: capacity_factor: 1 max_target_length: 4096 per_device_batch_size: 12 - remat_policy: "minimal" + remat_policy: "save_dot_with_context_except_mlp" diff --git a/examples/maxtext/configs/MI355X/llama3.1_405B-pretrain.yaml b/examples/maxtext/configs/MI355X/llama3.1_405B-pretrain.yaml new file mode 100644 index 000000000..0111a5db7 --- /dev/null +++ b/examples/maxtext/configs/MI355X/llama3.1_405B-pretrain.yaml @@ -0,0 +1,51 @@ +work_group: ${PRIMUS_TEAM:amd} +user_name: ${PRIMUS_USER:root} +exp_name: ${PRIMUS_EXP_NAME:llama3.1_405B-pretrain} +workspace: ./output + +modules: + pre_trainer: + framework: maxtext + config: pre_trainer.yaml + + # model to run + model: llama3.1_405B.yaml + overrides: + run_name: "llama3.1_405b_training" + base_output_directory: "./output" + steps: 50 + log_period: 10 + profiler: "" + + # data + dataset_type: "synthetic" + hf_access_token: ${HF_TOKEN:""} + + # checkpoint + enable_checkpointing: false + async_checkpointing: false + + # inter-node parallelism strategy + dcn_data_parallelism: 1 + dcn_fsdp_parallelism: -1 + dcn_pipeline_parallelism: 1 + dcn_tensor_parallelism: 1 + dcn_sequence_parallelism: 1 + + # intra-node parallelism strategy + ici_fsdp_parallelism: -1 + ici_data_parallelism: 1 + ici_sequence_parallelism: 1 + ici_tensor_parallelism: 1 + ici_pipeline_parallelism: 1 + + remat_policy: 'full' + optimizer_memory_host_offload: False + param_scan_axis: 1 + megablox: False + + use_iota_embed: True + scan_layers: True + + max_target_length: 8192 + per_device_batch_size: 5 diff --git a/examples/maxtext/configs/MI355X/mixtral_8x7B-pretrain.yaml b/examples/maxtext/configs/MI355X/mixtral_8x7B-pretrain.yaml index c543f9eb1..f680155ac 100644 --- a/examples/maxtext/configs/MI355X/mixtral_8x7B-pretrain.yaml +++ b/examples/maxtext/configs/MI355X/mixtral_8x7B-pretrain.yaml @@ -38,5 +38,5 @@ modules: megablox: false capacity_factor: 1 max_target_length: 4096 - 
per_device_batch_size: 12 + per_device_batch_size: 11 remat_policy: "minimal" diff --git a/examples/run_local_pretrain.sh b/examples/run_local_pretrain.sh index 4cb36fdb3..d02da3f63 100755 --- a/examples/run_local_pretrain.sh +++ b/examples/run_local_pretrain.sh @@ -42,7 +42,7 @@ EXP=${EXP:-"examples/megatron/exp_pretrain.yaml"} # Default docker image if [ "${BACKEND:-}" = "MaxText" ]; then - DOCKER_IMAGE=${DOCKER_IMAGE:-"docker.io/rocm/jax-training:maxtext-v25.9"} + DOCKER_IMAGE=${DOCKER_IMAGE:-"docker.io/rocm/jax-training:maxtext-v26.1"} else DOCKER_IMAGE=${DOCKER_IMAGE:-"docker.io/rocm/primus:v26.1"} fi @@ -125,6 +125,13 @@ if [ "$USING_AINIC" == "1" ]; then ENV_ARGS+=("--env" "ANP_HOME_DIR") ENV_ARGS+=("--env" "MPI_HOME_DIR") + export TC_RESULTS=$(bash "${PRIMUS_PATH}/examples/scripts/detect_nccl_ib_tc.sh") + if [ -n "$TC_RESULTS" ]; then + echo "TC_RESULTS: $TC_RESULTS" + ENV_ARGS+=("--env" "TC_RESULTS") + else + echo "Failed to detect NCCL_IB_TC and NCCL_IB_FIFO_TC" + fi # VOLUME_ARGS+=(-v /mnt/shared:/mnt/shared) # VOLUME_ARGS+=(-v /etc/libibverbs.d/:/etc/libibverbs.d:ro) # VOLUME_ARGS+=(-v /usr/lib/x86_64-linux-gnu/libibverbs/:/usr/lib/x86_64-linux-gnu/libibverbs/:ro) diff --git a/examples/run_pretrain.sh b/examples/run_pretrain.sh index 016bc23db..930a5f2a9 100755 --- a/examples/run_pretrain.sh +++ b/examples/run_pretrain.sh @@ -175,40 +175,59 @@ export NCCL_CHECKS_DISABLE=1 # Set InfiniBand GID index for NCCL communication if [ "$USING_AINIC" == "1" ]; then - export ANP_HOME_DIR=${ANP_HOME_DIR:-"/opt/amd-anp"} - export RCCL_HOME_DIR=${RCCL_HOME_DIR:-"/opt/rccl"} - export MPI_HOME_DIR=${MPI_HOME_DIR:-"/opt/ompi"} - # Check which NCCL net plugin library is present under ${ANP_HOME_DIR}/build and set accordingly - if [ -f "${ANP_HOME_DIR}/build/librccl-anp.so" ]; then - export NCCL_NET_PLUGIN=librccl-anp.so - elif [ -f "${ANP_HOME_DIR}/build/librccl-net.so" ]; then - export NCCL_NET_PLUGIN=librccl-net.so - else - LOG_ERROR "Error: Neither librccl-anp.so nor 
librccl-net.so found in ${ANP_HOME_DIR}/build." - exit 1 - fi - LOG_INFO_RANK0 "Using AINIC" - LOG_INFO_RANK0 "RCCL_HOME_DIR: $RCCL_HOME_DIR" - LOG_INFO_RANK0 "ANP_HOME_DIR: $ANP_HOME_DIR" - LOG_INFO_RANK0 "MPI_HOME_DIR: $MPI_HOME_DIR" - # unset NCCL_IB_GID_INDEX export NCCL_IB_GID_INDEX=1 # export NCCL_IB_ROCE_VERSION_NUM=2 - export NCCL_MAX_P2P_CHANNELS=56 - export NCCL_IB_TC=104 - export NCCL_IB_FIFO_TC=192 + if [ -z "${TC_RESULTS:-}" ]; then + export NCCL_IB_TC=${NCCL_IB_TC:-104} + export NCCL_IB_FIFO_TC=${NCCL_IB_FIFO_TC:-192} + else + read -r NCCL_IB_TC NCCL_IB_FIFO_TC <<< "$TC_RESULTS" + export NCCL_IB_TC + export NCCL_IB_FIFO_TC + fi export NET_OPTIONAL_RECV_COMPLETION=1 export NCCL_IB_USE_INLINE=1 export RCCL_GDR_FLUSH_GPU_MEM_NO_RELAXED_ORDERING=0 export NCCL_GDR_FLUSH_DISABLE=1 - export NCCL_DMABUF_ENABLE=0 export NCCL_IGNORE_CPU_AFFINITY=1 - export NCCL_IB_QPS_PER_CONNECTION=1 + LOG_INFO_RANK0 "NCCL_IB_TC: $NCCL_IB_TC" + LOG_INFO_RANK0 "NCCL_IB_FIFO_TC: $NCCL_IB_FIFO_TC" + + if [ "${BACKEND:-}" == "MaxText" ]; then + # ------- RCCL/NCCL IB Tuning ------- + export IONIC_LOCKFREE=all + export NCCL_GDR_COPY_ENABLE=1 + export NCCL_IB_ECE_ENABLE=0 + export NCCL_IB_PCI_RELAXED_ORDERING=1 + + export NCCL_PXN_DISABLE=0 + export RCCL_LL128_FORCE_ENABLE=1 + else + export ANP_HOME_DIR=${ANP_HOME_DIR:-"/opt/amd-anp"} + export RCCL_HOME_DIR=${RCCL_HOME_DIR:-"/opt/rccl"} + export MPI_HOME_DIR=${MPI_HOME_DIR:-"/opt/ompi"} + # Check which NCCL net plugin library is present under ${ANP_HOME_DIR}/build and set accordingly + if [ -f "${ANP_HOME_DIR}/build/librccl-anp.so" ]; then + export NCCL_NET_PLUGIN=librccl-anp.so + elif [ -f "${ANP_HOME_DIR}/build/librccl-net.so" ]; then + export NCCL_NET_PLUGIN=librccl-net.so + else + LOG_ERROR "Error: Neither librccl-anp.so nor librccl-net.so found in ${ANP_HOME_DIR}/build." 
+ exit 1 + fi - export LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu:/usr/lib/x86_64-linux-gnu/libibverbs:${RCCL_HOME_DIR}/build/release:${ANP_HOME_DIR}/build:${MPI_HOME_DIR}/lib:$LD_LIBRARY_PATH + LOG_INFO_RANK0 "RCCL_HOME_DIR: $RCCL_HOME_DIR" + LOG_INFO_RANK0 "ANP_HOME_DIR: $ANP_HOME_DIR" + LOG_INFO_RANK0 "MPI_HOME_DIR: $MPI_HOME_DIR" + export NCCL_MAX_P2P_CHANNELS=56 + export NCCL_DMABUF_ENABLE=0 + export NCCL_IB_QPS_PER_CONNECTION=1 + + export LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu:/usr/lib/x86_64-linux-gnu/libibverbs:${RCCL_HOME_DIR}/build/release:${ANP_HOME_DIR}/build:${MPI_HOME_DIR}/lib:$LD_LIBRARY_PATH + fi else export NCCL_IB_GID_INDEX=3 fi @@ -280,9 +299,15 @@ if [ "${BACKEND:-}" == "MaxText" ]; then export DUMP_HLO_DIR=${DUMP_HLO_DIR:-"${PRIMUS_PATH}/output/xla_dump_hlo"} export DUMP_HLO=${DUMP_HLO:-0} export NVTE_ALLOW_NONDETERMINISTIC_ALGO=1 - export XLA_PYTHON_CLIENT_MEM_FRACTION=.97 + if [ "${NNODES}" -gt 1 ]; then + export XLA_PYTHON_CLIENT_MEM_FRACTION=.93 + export JAX_HIP_GRAPH_LOWERING=false + else + export XLA_PYTHON_CLIENT_MEM_FRACTION=.97 + fi + export TF_CPP_MIN_LOG_LEVEL=2 # this env var is used to suppress the error logs at the end of training + export XLA_FLAGS="--xla_gpu_memory_limit_slop_factor=95 --xla_gpu_reduce_scatter_combine_threshold_bytes=8589934592 --xla_gpu_enable_command_buffer='' --xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_all_gather_combine_threshold_bytes=8589934592 --xla_gpu_enable_triton_gemm=false --xla_gpu_enable_cublaslt=true --xla_gpu_autotune_level=4 --xla_gpu_enable_all_gather_combine_by_dim=false" export NVTE_USE_HIPBLASLT=1 - export XLA_FLAGS="--xla_gpu_memory_limit_slop_factor=95 --xla_gpu_reduce_scatter_combine_threshold_bytes=8589934592 --xla_gpu_graph_level=0 --xla_gpu_enable_latency_hiding_scheduler=True --xla_gpu_all_gather_combine_threshold_bytes=8589934592 --xla_gpu_enable_triton_gemm=False --xla_gpu_enable_cublaslt=True --xla_gpu_autotune_level=0 
--xla_gpu_enable_all_gather_combine_by_dim=FALSE" if [ "${DUMP_HLO}" = "1" ]; then mkdir -p "${DUMP_HLO_DIR}" export XLA_FLAGS="$XLA_FLAGS --xla_dump_to=$DUMP_HLO_DIR" @@ -411,6 +436,21 @@ if [[ "$PATCH_TE_FLASH_ATTN" == "1" ]]; then fi LOG_INFO_RANK0 "" +# -------------------- Install required packages for Jax -------------------- +install_pkgs_for_maxtext() { + LOG_INFO_RANK0 "========== Install IB required packages for Jax/MaxText ==========" + apt update + apt install autoconf automake libtool pkg-config -y + apt install jq dpkg-dev kmod xz-utils -y + apt install libibverbs-dev ibverbs-utils infiniband-diags -y + apt install rdma-core librdmacm-dev libibverbs-dev libibumad-dev -y + LOG_INFO_RANK0 "========== Install IB required packages for Jax/MaxText Done ==========" +} + +if [[ "$NNODES" -gt 1 ]] && [[ "${BACKEND:-}" == "MaxText" ]]; then + install_pkgs_for_maxtext +fi + # ----------------- Rebuild nbxt ----------------- export REBUILD_BNXT=${REBUILD_BNXT:-0} export PATH_TO_BNXT_TAR_PACKAGE=${PATH_TO_BNXT_TAR_PACKAGE} @@ -431,20 +471,6 @@ else LOG_INFO "Skip bnxt rebuild. 
REBUILD_BNXT=$REBUILD_BNXT, PATH_TO_BNXT_TAR_PACKAGE=$PATH_TO_BNXT_TAR_PACKAGE" fi -# -------------------- Install required packages for Jax -------------------- -install_pkgs_for_maxtext() { - LOG_INFO_RANK0 "========== Install required packages for Jax/MaxText ==========" - apt install iproute2 -y - apt install -y linux-headers-"$(uname -r)" libelf-dev - apt install -y gcc make libtool autoconf librdmacm-dev rdmacm-utils infiniband-diags ibverbs-utils perftest ethtool libibverbs-dev \ - rdma-core strace libibmad5 libibnetdisc5 ibverbs-providers libibumad-dev libibumad3 libibverbs1 libnl-3-dev libnl-route-3-dev - LOG_INFO_RANK0 "========== Install required packages for Jax/MaxText Done ==========" -} - -if [[ "$NNODES" -gt 1 ]] && [[ "${BACKEND:-}" == "MaxText" ]]; then - install_pkgs_for_maxtext -fi - # -------------------- HipBLASLt Tuning -------------------- handle_hipblaslt_tuning() { local STAGE=${PRIMUS_HIPBLASLT_TUNING_STAGE:-0} diff --git a/examples/run_slurm_pretrain.sh b/examples/run_slurm_pretrain.sh index 04da35a4d..d13e6638f 100755 --- a/examples/run_slurm_pretrain.sh +++ b/examples/run_slurm_pretrain.sh @@ -19,6 +19,7 @@ Optional Environment Variables: NNODES Number of nodes to use [default: 1] MASTER_PORT Master port [default: 12345] LOG_DIR Directory for log output [default: ./output] + NODE_LIST Comma-separated list of nodes for srun --nodelist [default: unset] Example: export DATA_PATH=/mnt/data @@ -35,12 +36,13 @@ export NNODES=${NNODES:-1} SCRIPT_DIR=$(dirname "$(realpath "${BASH_SOURCE[0]}")") export LOG_DIR=${LOG_DIR:-"./output"} -LOG_FILE="${LOG_DIR}/log_slurm_pretrain.txt" +LOG_FILE="${LOG_DIR}/log_slurm_pretrain_$(date +%Y%m%d_%H%M%S).txt" mkdir -p "$LOG_DIR" srun -N "${NNODES}" \ --exclusive \ --export ALL \ + ${NODE_LIST:+--nodelist="${NODE_LIST}"} \ --ntasks-per-node=1 \ --cpus-per-task="${CPUS_PER_TASK:-128}" \ bash -c " @@ -58,5 +60,5 @@ srun -N "${NNODES}" \ export NODE_RANK=\${SLURM_PROCID} export 
GPUS_PER_NODE=\${SLURM_GPUS_ON_NODE} export REBUILD_PRIMUS_TURBO=\${REBUILD_PRIMUS_TURBO} - bash ${SCRIPT_DIR}/run_local_pretrain.sh \"\$@\" 2>&1 | tee ${LOG_FILE} - " bash "$@" + bash ${SCRIPT_DIR}/run_local_pretrain.sh \"\$@\" + " bash "$@" 2>&1 | tee "${LOG_FILE}" diff --git a/examples/scripts/detect_nccl_ib_tc.sh b/examples/scripts/detect_nccl_ib_tc.sh new file mode 100644 index 000000000..e24f0cf35 --- /dev/null +++ b/examples/scripts/detect_nccl_ib_tc.sh @@ -0,0 +1,110 @@ +#!/bin/bash +############################################################################### +# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### +# +# Auto-detect correct NCCL_IB_TC and NCCL_IB_FIFO_TC for Pensando AINIC clusters. +# Reads QoS DSCP-to-priority mapping from nicctl and finds the PFC-protected DSCP. +# +# Usage: +# source detect_nccl_ib_tc.sh # sets NCCL_IB_TC and NCCL_IB_FIFO_TC +# eval $(./detect_nccl_ib_tc.sh) # alternative: export from subshell + +set -euo pipefail + +is_pensando() { + local ib_dev="" + + for dev in /sys/class/infiniband/*; do + [ -e "$dev" ] || continue + ib_dev=$(basename "$dev") + break + done + [ -z "$ib_dev" ] && return 1 + + if echo "$ib_dev" | grep -qi "ionic"; then + return 0 + fi + + local ca_type + ca_type=$(ibstat "$ib_dev" 2>/dev/null | grep "CA type:" | head -1 || true) + echo "$ca_type" | grep -qi "Pensando" +} + +detect_pensando_tc() { + if ! 
command -v nicctl &>/dev/null; then + echo "WARN: nicctl not found, using known Pensando defaults" >&2 + echo "104 192" + return + fi + + local qos_output + qos_output=$(nicctl show qos 2>/dev/null) || { + echo "WARN: nicctl show qos failed, using defaults" >&2 + echo "104 192" + return + } + + local pfc_prio + pfc_prio=$(echo "$qos_output" | grep "PFC no-drop priorities" | head -1 | awk '{print $NF}') + + if [ -z "$pfc_prio" ]; then + echo "WARN: Could not determine PFC priority, using defaults" >&2 + echo "104 192" + return + fi + + # nicctl output lines look like: + # DSCP : 26 ==> priority : 3 + # DSCP bitmap : 0x0000000004000000 ==> priority : 3 + # DSCP : 0-25, 27-47, 49-63 ==> priority : 0 + # We want the single DSCP (not bitmap, not range) that maps to each priority. + + # Helper: extract single-value DSCP for a given priority + extract_dscp_for_priority() { + echo "$qos_output" \ + | grep -v "bitmap" \ + | grep "DSCP" \ + | grep "==> priority : ${1}$" \ + | head -1 \ + | sed 's/.*DSCP[^:]*: *//' \ + | sed 's/ *==> .*//' \ + | tr -d ' ' + } + + # NCCL_IB_TC: use DSCP that maps to PFC-protected (no-drop) priority + local data_dscp + data_dscp=$(extract_dscp_for_priority "$pfc_prio") + + if ! echo "$data_dscp" | grep -qE '^[0-9]+$'; then + echo "WARN: Could not parse DSCP for PFC priority $pfc_prio, using defaults" >&2 + echo "104 192" + return + fi + + # NCCL_IB_FIFO_TC: use DSCP that maps to the strict-priority queue + # (scheduling output: priority N has "strict" type) + local strict_prio + strict_prio=$(echo "$qos_output" | grep -i "strict" | head -1 | awk '{print $1}') + local fifo_dscp="" + if [ -n "$strict_prio" ] && echo "$strict_prio" | grep -qE '^[0-9]+$'; then + fifo_dscp=$(extract_dscp_for_priority "$strict_prio") + fi + + if ! echo "$fifo_dscp" | grep -qE '^[0-9]+$'; then + echo "WARN: Could not find strict-priority DSCP, using same as data" >&2 + fifo_dscp="$data_dscp" + fi + + echo "$((data_dscp * 4)) $((fifo_dscp * 4))" +} + +if ! 
is_pensando; then + echo "# Not a Pensando AINIC cluster, no NCCL_IB_TC override needed" >&2 + exit 0 +fi + +result=$(detect_pensando_tc) +echo "$result" diff --git a/primus/backends/maxtext/checkpointing.py b/primus/backends/maxtext/checkpointing.py index a04dfc4b1..8c493f18a 100644 --- a/primus/backends/maxtext/checkpointing.py +++ b/primus/backends/maxtext/checkpointing.py @@ -7,9 +7,185 @@ from typing import Any +import jax import orbax.checkpoint as ocp +import orbax.checkpoint.experimental.emergency.checkpoint_manager as emergency_checkpoint_manager +import orbax.checkpoint.experimental.emergency.replicator_checkpoint_manager as emergency_replicator_checkpoint_manager from etils import epath +from flax.training import train_state from MaxText import max_logging +from MaxText.checkpointing import ( + _load_full_state_from_path, + _replica_devices, + _restore_grain_iterator, + load_params_from_path, +) +from MaxText.input_pipeline.input_pipeline_interface import PlaceHolderDataIterator +from MaxText.multihost_dataloading import MultiHostDataLoadIterator + +Composite = ocp.args.Composite +EmergencyCheckpointManager = emergency_checkpoint_manager.CheckpointManager +EmergencyReplicatorCheckpointManager = emergency_replicator_checkpoint_manager.ReplicatorCheckpointManager + + +def load_state_if_possible( + checkpoint_manager: ocp.CheckpointManager | None, + data_iterator: MultiHostDataLoadIterator | list[MultiHostDataLoadIterator] | None, + load_parameters_from_path: str, + load_full_state_from_path: str, + checkpoint_storage_concurrent_gb: int, + abstract_unboxed_pre_state: train_state.TrainState, + enable_single_replica_ckpt_restoring: bool | None = False, + dataset_type: str | None = "tfds", + step: int = -1, # -1 means latest + use_ocdbt=True, + use_zarr3=True, + enable_orbax_v1=False, + checkpoint_conversion_fn=None, + source_checkpoint_layout="orbax", + expansion_factor_real_data: int = -1, +): + """Loads TrainState as possible from the inputs. 
+ + Args: + checkpoint_manager: if the checkpoint_manager has a valid checkpoint, return + that TrainState. This enables a full reload of a run in progress. + load_parameters_from_path: if there is no checkpoint in the checkpoint + manager, load parameters from a parameter only checkpoint at this path. + load_full_state_from_path: if there is no checkpoint in the checkpoint + manager, load full state from a full state checkpoint at this path. + abstract_unboxed_pre_state: an unboxed, abstract TrainState that Orbax + matches type against. + enable_single_replica_ckpt_restoring: bool flag for restoring checkpoitn + with SingleReplicaArrayHandler + checkpoint_storage_concurrent_gb: concurrent GB for checkpoint byte I/O. + enable_orbax_v1: bool flag for enabling Orbax v1. + checkpoint_conversion_fn: function for converting checkpoint to Orbax v1. + source_checkpoint_layout: Optional checkpoint context to use for loading, + provided in string format with the default being "orbax". + + Returns: + A tuple of (train_state, train_state_params) where full_train_state captures + a full reload and train_state_params just the params for a partial reload. + At most one will be non-None. Both can be None if neither checkpoint is + set. 
+ """ + + if checkpoint_manager is not None: + max_logging.log("checkpoint manager exists so trying to load this run's existing checkpoint") + + step = checkpoint_manager.latest_step() if step < 0 else step + if step is not None: + max_logging.log(f"restoring from this run's directory step {step}") + + def map_to_pspec(data): + if not enable_single_replica_ckpt_restoring: + return ocp.type_handlers.ArrayRestoreArgs(sharding=data.sharding) + pspec = data.sharding.spec + mesh = data.sharding.mesh + replica_axis_index = 0 + replica_devices = _replica_devices(mesh.devices, replica_axis_index) + replica_mesh = jax.sharding.Mesh(replica_devices, mesh.axis_names) + single_replica_sharding = jax.sharding.NamedSharding(replica_mesh, pspec) + + return ocp.type_handlers.SingleReplicaArrayRestoreArgs( + sharding=jax.sharding.NamedSharding(mesh, pspec), + single_replica_sharding=single_replica_sharding, + global_shape=data.shape, + dtype=data.dtype, + ) + + # Cache the original ArrayHandler before potentially overriding it. + # This is the same handler used when enable_single_replica_ckpt_restoring=False. + original_array_handler = ocp.type_handlers.get_type_handler(jax.Array) + + # Register SingleReplicaArrayHandler globally for restore (if enabled) + if enable_single_replica_ckpt_restoring: + single_replica_handler = ocp.type_handlers.SingleReplicaArrayHandler( + replica_axis_index=0, + broadcast_memory_limit_bytes=1024 * 1024 * 1000, # 1000 MB limit + ) + ocp.type_handlers.register_type_handler(jax.Array, single_replica_handler, override=True) + + restore_args = jax.tree_util.tree_map(map_to_pspec, abstract_unboxed_pre_state) + checkpoint_args = ocp.args.PyTreeRestore( + item=abstract_unboxed_pre_state, restore_args=restore_args + ) + + def _restore_original_array_handler(): + """Restore the original ArrayHandler after SingleReplicaArrayHandler restore. + + This is critical because SingleReplicaArrayHandler is designed for restore only. 
+ Using it for saves will cause missing array_metadatas files and checkpoint failures. + We restore the EXACT handler that was in place before, not a new instance. + """ + if enable_single_replica_ckpt_restoring: + max_logging.log( + "Restoring original ArrayHandler after SingleReplicaArrayHandler restore..." + ) + # Re-register the original handler that was cached before the override + ocp.type_handlers.register_type_handler(jax.Array, original_array_handler, override=True) + max_logging.log("Original ArrayHandler restored successfully.") + + match (checkpoint_manager, dataset_type, data_iterator): + # Case 1: Matches if 'checkpoint_manager' is an instance of either EmergencyCheckpointManager + # or EmergencyReplicatorCheckpointManager. The '_' indicates that 'dataset_type' and + # 'data_iterator' can be any value and aren't used in this pattern. + case (checkpoint_manager, _, _) if isinstance( + checkpoint_manager, (EmergencyCheckpointManager, EmergencyReplicatorCheckpointManager) + ): + result = ( + checkpoint_manager.restore(step, args=Composite(state=checkpoint_args)).state, + None, + ) + _restore_original_array_handler() + return result + # Case 2: Matches if dataset type is "grain" and the data iterator is not a + # PlaceHolderDataIterator and a specific checkpoint file exists for the iterator + case ( + checkpoint_manager, + dataset_type, + data_iterator, + ) if ( + dataset_type == "grain" + and data_iterator + and not isinstance(data_iterator, PlaceHolderDataIterator) + and (checkpoint_manager.directory / str(step) / "iter").exists() + ): + result = _restore_grain_iterator( + checkpoint_manager, step, data_iterator, checkpoint_args, expansion_factor_real_data + ) + _restore_original_array_handler() + return result + # Case 3: Default/Fallback case. + # This case acts as a wildcard ('_') and matches if none of the preceding cases were met. 
+ case _: + result = (checkpoint_manager.restore(step, args=Composite(items=checkpoint_args)), None) + _restore_original_array_handler() + return result + + if load_parameters_from_path != "": + restored_params = load_params_from_path( + load_parameters_from_path, + abstract_unboxed_pre_state.params, + checkpoint_storage_concurrent_gb, + use_ocdbt=use_ocdbt, + use_zarr3=use_zarr3, + ) + return None, restored_params + elif load_full_state_from_path != "": + max_logging.log(f"Loading full state from path: {load_full_state_from_path}") + restored_state = _load_full_state_from_path( + path=load_full_state_from_path, + abstract_unboxed_pre_state=abstract_unboxed_pre_state, + enable_orbax_v1=enable_orbax_v1, + checkpoint_conversion_fn=checkpoint_conversion_fn, + source_checkpoint_layout=source_checkpoint_layout, + ) + return {"items": restored_state}, None + else: + max_logging.log("No existing checkpoints found, not restoring checkpoint.") + return None, None def create_orbax_checkpoint_manager( @@ -30,17 +206,18 @@ def create_orbax_checkpoint_manager( max_logging.log(f"Creating checkpoint manager with ocdbt={use_ocdbt} and zarr3={use_zarr3}") + # Base configuration for all dataset types + item_names = ("items",) + # we need to use ocdbt and zarr3 to control max file size in the checkpoint + item_handlers = {"items": ocp.PyTreeCheckpointHandler(use_ocdbt=use_ocdbt, use_zarr3=use_zarr3)} + if dataset_type == "grain": - item_names = ("items", "iter") - else: - item_names = ("items",) + item_names += ("iter",) + item_handlers["iter"] = ocp.GrainCheckpointHandler() # local storage checkpoint needs parent directory created p = epath.Path(checkpoint_dir) p.mkdir(exist_ok=True, parents=True) - # we need to use ocdbt and zarr3 to control max file size in the checkpoint - # omitting `iter` uses default handler for `iter` - item_handlers = {"items": ocp.PyTreeCheckpointHandler(use_ocdbt=use_ocdbt, use_zarr3=use_zarr3)} manager = ocp.CheckpointManager( p, item_names=item_names, 
diff --git a/primus/backends/maxtext/configs/__init__.py b/primus/backends/maxtext/configs/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/primus/backends/maxtext/configs/types.py b/primus/backends/maxtext/configs/types.py new file mode 100644 index 000000000..a6cf03703 --- /dev/null +++ b/primus/backends/maxtext/configs/types.py @@ -0,0 +1,239 @@ +############################################################################### +# Copyright 2023–2025 Google LLC. All rights reserved. +# Modification Copyright© 2025 Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### +import os +from typing import Any + +from MaxText.configs.types import ( # Run and Checkpointing; Data Types and Quantization; Core Model Architecture; Attention Mechanisms; Mixture of Experts; Parallelism and Layout; Training, Optimization, and Fine-Tuning; Reinforcement Learning; Positional Embeddings; Dataset Loading and Tokenization; Inference; Development and Debugging; Metrics and Monitoring; Multimodal; Derived + AOT, + GRPO, + MTP, + VLLM, + AdamW, + Attention, + Checkpointing, + DatasetGeneral, + DataTypes, + DcnParallelism, + Debug, + Decoding, + DeepSeekMoE, + DerivedValues, + DevelopmentAndDebugging, + EmergencyCheckpointing, + FineTuning, + GcpMonitoring, + Goodput, + GrainDataset, + HardwareAndMesh, + HfDataset, + HloDump, + IciParallelism, + InferenceBenchmark, + InferenceGeneral, + InferenceLayout, + InferenceServer, + LayoutAndSharding, + Llama4Attention, + Logits, + MaxTextConfig, + Metrics, + MlaAttention, + MoBa, + ModelArchitecture, + MoEGeneral, + MoEKernels, + MultimodalGeneral, + Optimizer, + OrbaxStorage, + PagedAttention, + PipelineParallelism, + PositionalEmbedding, + PrefixCaching, + Profiling, + Quantization, + Qwen3Next, + RematAndOffload, + Reward, + RLDataset, + RLEvaluation, + RLHardware, + Rope, + RunInfo, + SpecialTokens, + 
SplashAttention, + StackTrace, + Tensorboard, + TfdsDataset, + Tokenizer, + TrainingLoop, + VisionProjector, + VisionTower, + YarnRope, +) +from pydantic import BaseModel, ConfigDict, model_validator +from pydantic.fields import Field + + +class PrimusMoEGeneral(MoEGeneral): + expert_balance: bool = Field(False, description="Whether to use expert balancing.") + + +class PrimusDevelopmentAndDebugging(DevelopmentAndDebugging): + jax_distributed_heartbeat_timeout_seconds: int = Field( + 100, + description="How long before a missing heartbeat marks a task as dead. Increase for slow NFS checkpoint restores.", + ) + + +class PrimusTurboConfig(BaseModel): + enable_primus_turbo: bool = Field(False, description="Whether to enable Primus Turbo.") + use_turbo_grouped_gemm: bool = Field(False, description="Whether to use turbo grouped gemm.") + + +class PrimusWandbConfig(BaseModel): + enable_wandb: bool = Field(False, description="Whether to enable WandB.") + wandb_project: None | str = Field(None, description="The name of the WandB project.") + wandb_exp_name: None | str = Field( + None, description="The name of the WandB experiment, derived from the run_name if not set." 
+ ) + wandb_save_dir: None | str = Field(None, description="The directory to save the WandB logs.") + + +class PrimusMaxTextConfig( + # Run and Checkpointing + RunInfo, + Checkpointing, + OrbaxStorage, + EmergencyCheckpointing, + # Data Types and Quantization + DataTypes, + Quantization, + # Core Model Architecture + ModelArchitecture, + MTP, + Logits, + # Attention Mechanisms + Attention, + MlaAttention, + MoBa, + Llama4Attention, + SplashAttention, + PagedAttention, + # Mixture of Experts - REPLACED with PrimusMoEGeneral + PrimusMoEGeneral, # Replaces MoEGeneral + MoEKernels, + DeepSeekMoE, + Qwen3Next, + # Parallelism and Layout + HardwareAndMesh, + LayoutAndSharding, + DcnParallelism, + IciParallelism, + PipelineParallelism, + # Training, Optimization, and Fine-Tuning + RematAndOffload, + TrainingLoop, + Optimizer, + AdamW, + FineTuning, + # Reinforcement Learning + RLHardware, + VLLM, + GRPO, + RLDataset, + RLEvaluation, + Reward, + SpecialTokens, + # Positional Embeddings + PositionalEmbedding, + Rope, + YarnRope, + # Dataset Loading and Tokenization + DatasetGeneral, + TfdsDataset, + HfDataset, + GrainDataset, + Tokenizer, + # Inference + InferenceGeneral, + Decoding, + InferenceLayout, + InferenceServer, + InferenceBenchmark, + PrefixCaching, + # Development and Debugging - REPLACED with PrimusDevelopmentAndDebugging + AOT, + PrimusDevelopmentAndDebugging, # Replaces DevelopmentAndDebugging + Profiling, + HloDump, + StackTrace, + # Metrics and Monitoring + Metrics, + Goodput, + GcpMonitoring, + Tensorboard, + # Multimodal + MultimodalGeneral, + VisionTower, + VisionProjector, + # Primus-specific configs - ADDED + PrimusTurboConfig, + PrimusWandbConfig, + # Derived + DerivedValues, +): + """ + The main configuration object for Primus MaxText. 
+ + This class extends MaxTextConfig with Primus-specific configurations: + - Replaces MoEGeneral with PrimusMoEGeneral (adds expert_balance) + - Replaces DevelopmentAndDebugging with PrimusDevelopmentAndDebugging (adds jax_distributed_heartbeat_timeout_seconds) + - Adds PrimusTurboConfig (Primus Turbo optimizations) + - Adds PrimusWandbConfig (WandB integration) + + All other functionality from MaxTextConfig is preserved. + """ + + debug: Debug = Field(default_factory=Debug) + model_config = ConfigDict(extra="forbid", protected_namespaces=()) + + @model_validator(mode="before") + @classmethod + def load_model_specific_defaults(cls, values: dict[str, Any]) -> dict[str, Any]: + """This method is a no-op because `pyconfig` handles model-specific config loading.""" + return values + + @model_validator(mode="after") + def set_derived_and_validate_values(self) -> "PrimusMaxTextConfig": + """ + Computes all derived values and runs all cross-field validations after initial parsing. + This calls the MaxTextConfig's validation logic and then adds any Primus-specific validations. + """ + # Call MaxTextConfig's validation logic directly since we're using composition via multiple inheritance + # rather than direct inheritance. MaxTextConfig.set_derived_and_validate_values expects a MaxTextConfig + # instance, but since we have all the same base classes, we can call it on self. + # We need to temporarily cast self to MaxTextConfig for the method call, or call the method directly. + # Actually, since MaxTextConfig's method works on the same fields we have, we can call it directly. 
+ MaxTextConfig.set_derived_and_validate_values(self) + + # Add any Primus-specific validations here if needed + if (self.wandb_save_dir is None or self.wandb_save_dir == "") and self.base_output_directory: + self.wandb_save_dir = os.path.join(self.base_output_directory, "wandb") + + if self.wandb_project is None or self.wandb_project == "": + self.wandb_project = os.getenv("WANDB_PROJECT", "Primus-MaxText-Pretrain") + + if (self.wandb_exp_name is None or self.wandb_exp_name == "") and self.run_name: + self.wandb_exp_name = self.run_name + + if self.enable_wandb and "WANDB_API_KEY" not in os.environ: + raise ValueError("WANDB_API_KEY is not set. Please set it or login wandb before proceeding") + + if not self.enable_primus_turbo: + self.use_turbo_grouped_gemm = False + + return self diff --git a/primus/backends/maxtext/input_pipeline/_hf_data_processing.py b/primus/backends/maxtext/input_pipeline/_hf_data_processing.py index 641ac3845..79c059615 100644 --- a/primus/backends/maxtext/input_pipeline/_hf_data_processing.py +++ b/primus/backends/maxtext/input_pipeline/_hf_data_processing.py @@ -12,11 +12,9 @@ import numpy as np import transformers from MaxText import multihost_dataloading -from MaxText.input_pipeline import _input_pipeline_utils +from MaxText.input_pipeline import _input_pipeline_utils, instruction_data_processing from MaxText.input_pipeline._hf_data_processing import vision_sft_preprocessing_pipeline -from .custom_packed_batch import CustomPackAndBatchOperation - def preprocessing_pipeline( dataloading_host_index, @@ -31,18 +29,19 @@ def preprocessing_pipeline( max_target_length, shuffle, data_shuffle_seed, + chat_template_path="", add_bos=True, add_eos=True, packing=True, shift=True, num_threads=1, - drop_remainder=False, - generate_padding_example=False, + drop_remainder=True, + generate_padding_batch=False, use_dpo=None, use_sft=None, sft_train_on_completion_only=True, grain_worker_count=1, # only support 0 or 1 - max_segments=1, # max segments per 
sequence + max_segments_per_seq=1, # max segments per sequence ): """pipeline for preprocessing HF dataset""" assert ( @@ -63,10 +62,16 @@ def preprocessing_pipeline( if use_sft: dataset = dataset.select_columns(data_column_names) - supported_columns = [["prompt", "completion"], ["messages"]] + supported_columns = [["prompt", "completion"], ["messages"], ["question", "answer"]] assert any( set(data_column_names) == set(supported) for supported in supported_columns ), f"Dataset column names mismatch. Expected columns to match one of {supported_columns}, but got {data_column_names}" + + # convert instruction dataset to conversational format + dataset, data_column_names = instruction_data_processing.convert_to_conversational_format( + dataset=dataset, data_columns=data_column_names, chat_template_path=chat_template_path + ) + assert _input_pipeline_utils.is_conversational( dataset.features, data_column_names ), "Dataset is not in conversational format." @@ -119,7 +124,6 @@ def preprocessing_pipeline( dataloading_host_index, dataloading_host_count, num_threads, - generate_padding_example, max_target_length, data_column_names, ) @@ -147,17 +151,17 @@ def lists2array(x): data_column_names = ("inputs", "targets") if packing and not use_dpo: - # monkey patch the splitter to handle TE's maximum segment limitation length_struct = {col: max_target_length for col in data_column_names} - pack_and_batch = CustomPackAndBatchOperation( - batch_size=global_batch_size // jax.process_count(), - length_struct=length_struct, - max_segments=max_segments, + operations.append( + grain.experimental.PackAndBatchOperation( + batch_size=global_batch_size // jax.process_count(), + length_struct=length_struct, + max_sequences_per_bin=max_segments_per_seq, + ) ) - operations.append(pack_and_batch) operations.append(_input_pipeline_utils.ReformatPacking(data_column_names)) else: - operations.append(_input_pipeline_utils.PadToMaxLength(max_target_length, pad_id)) + 
operations.append(_input_pipeline_utils.PadOrTrimToMaxLength(max_target_length, pad_id)) operations.append( grain.Batch(batch_size=global_batch_size // jax.process_count(), drop_remainder=drop_remainder) ) @@ -189,7 +193,9 @@ def lists2array(x): read_options=grain.ReadOptions(num_threads=num_threads, prefetch_buffer_size=128), ) - multihost_gen = multihost_dataloading.MultiHostDataLoadIterator(dataloader, global_mesh) + multihost_gen = multihost_dataloading.MultiHostDataLoadIterator( + dataloader, global_mesh, generate_padding_batch + ) # Return multi-host jax.Array prep iterator return multihost_gen @@ -237,11 +243,12 @@ def make_hf_train_iterator( add_bos=config.add_bos, add_eos=config.add_eos, packing=config.packing, - generate_padding_example=False, + generate_padding_batch=config.generate_padding_batch_train, use_dpo=config.use_dpo, use_sft=config.use_sft, sft_train_on_completion_only=config.sft_train_on_completion_only, - max_segments=config.max_segments, + chat_template_path=config.chat_template_path, + max_segments_per_seq=config.max_segments_per_seq, ) return train_iter @@ -261,7 +268,6 @@ def make_hf_eval_iterator( token=config.hf_access_token, ) - eval_generate_padding_example = config.eval_steps > 0 if config.use_sft and config.use_multimodal: eval_iter = vision_sft_preprocessing_pipeline( dataset=eval_ds, @@ -290,10 +296,11 @@ def make_hf_eval_iterator( add_bos=config.add_bos, add_eos=config.add_eos, packing=config.packing, - generate_padding_example=eval_generate_padding_example, + generate_padding_batch=config.generate_padding_batch_eval, use_dpo=config.use_dpo, use_sft=config.use_sft, sft_train_on_completion_only=config.sft_train_on_completion_only, - max_segments=config.max_segments, + chat_template_path=config.chat_template_path, + max_segments_per_seq=config.max_segments_per_seq, ) return eval_iter diff --git a/primus/backends/maxtext/input_pipeline/custom_packed_batch.py b/primus/backends/maxtext/input_pipeline/custom_packed_batch.py deleted 
file mode 100644 index 3f6b2e21e..000000000 --- a/primus/backends/maxtext/input_pipeline/custom_packed_batch.py +++ /dev/null @@ -1,215 +0,0 @@ -############################################################################### -# Copyright 2023–2025 Google LLC. All rights reserved. -# Modification Copyright© 2025 Advanced Micro Devices, Inc. All rights reserved. -# -# See LICENSE for license information. -############################################################################### -""" -Forked from https://github.com/google/grain/blob/7841100258c90c77fcebdd668232aea9c0314fc2/grain/_src/python/experimental/example_packing/packing.py - -Customized packing based on MaxText's default -Modified to support max segments per sequence - -""" - -import dataclasses -from typing import Any, Generic, Iterator, TypeVar, Union, cast - -import numpy as np -from absl import logging -from grain._src.core import tree_lib -from grain._src.python import record - -_T = TypeVar("_T") - - -class _PackedBatch: - """Class to represent a batch of packed examples.""" - - def __init__( - self, - element_for_shapes: Any, # PyTree[np.ndarray] - batch_size: int, - length_struct: Any, # PyTree[int] - max_segments: int, - ): - self._batch_size = batch_size - self._length_struct = length_struct - self._max_segments = max_segments - - # Define the main buffers we will pack the data into. - def make_packed_buffer(length: int, input_arr: np.ndarray): - return np.zeros( - shape=(batch_size, length, *input_arr.shape[1:]), # (B, T, ...) 
- dtype=input_arr.dtype, - ) - - self._batch = tree_lib.map_structure(make_packed_buffer, length_struct, element_for_shapes) - - def make_packed_aux_info(length: int): - return np.zeros(shape=(batch_size, length), dtype=np.int32) - - self._segmentations = tree_lib.map_structure(make_packed_aux_info, length_struct) - self._positions = tree_lib.map_structure(make_packed_aux_info, length_struct) - - # Tracks the next empty position to insert an example for each row - # in the batch, for each feature in features_to_pack. - self._first_free_cell_per_row = tree_lib.map_structure( - lambda _: np.zeros(batch_size, dtype=np.int32), length_struct - ) - - # Tracks the number of examples already packed into row of the batch. Used - # to fill the segmentation values for each feature. - self._num_examples_per_row = [0 for _ in range(batch_size)] - - # For determinism, the metadata.index for the packed batch must match - # metadata.index of the _last_ included input example. - self._last_record_metadata = None - - def get_packed_batch(self) -> record.Record[tuple[_T, _T, _T]]: - assert self._last_record_metadata is not None - return record.Record( - metadata=cast(record.RecordMetadata, self._last_record_metadata), - data=(self._batch, self._segmentations, self._positions), - ) - - def _can_add_at_row( - self, - element: Any, # PyTree[np.ndarray] - ) -> int: - """Returns the index of the first row which fits element, or -1 if none.""" - element_feature_lengths = tree_lib.map_structure(len, element) - - # Check no feature exceeds max length - length_exceeded = tree_lib.map_structure( - lambda feature_length, max_length: feature_length > max_length, - element_feature_lengths, - self._length_struct, - ) - if any(tree_lib.flatten(length_exceeded)): - raise ValueError("Inputs to PackAndBatchOperation must be truncated to max length.") - - # For each row, check whether the total length after adding the current - # element would exceed max feature lengths. 
- def _feature_will_fit(feature_length, first_free_cell, max_length): - return feature_length + first_free_cell <= max_length - - is_row_free_struct = tree_lib.map_structure( - _feature_will_fit, element_feature_lengths, self._first_free_cell_per_row, self._length_struct - ) - - ## Pick first row (if exists) where element can be added. - for i in range(self._batch_size): - if self._num_examples_per_row[i] < self._max_segments: - row_is_free_per_feature = [free[i] for free in tree_lib.flatten(is_row_free_struct)] - if all(row_is_free_per_feature): - return i - return -1 - - def add_element_to_batch( - self, - element: Any, # PyTree[np.ndarray] - row: int, - ) -> None: - """Adds element to current batch at the specified row.""" - # Apply updates to each feature. - for per_feature_data in zip( - tree_lib.flatten(element), - tree_lib.flatten(self._batch), - tree_lib.flatten(self._segmentations), - tree_lib.flatten(self._positions), - tree_lib.flatten(self._first_free_cell_per_row), - ): - value, batch_value, segmentations, positions, first_free_cell_per_row = per_feature_data - # Update batch value, segmentations, and positions. - start = first_free_cell_per_row[row] - end = first_free_cell_per_row[row] + len(value) - batch_value[row][start:end] = value - segmentations[row][start:end] = self._num_examples_per_row[row] + 1 - positions[row][start:end] = np.arange(end - start) - # Update first_free_cell_per_row. - first_free_cell_per_row[row] += len(value) - - self._num_examples_per_row[row] += 1 - - def try_add_to_batch(self, element: record.Record) -> bool: - """Finds a row in the batch at which element can be added.""" - if (row_idx := self._can_add_at_row(element.data)) == -1: - return False - self.add_element_to_batch(element.data, row_idx) - self._last_record_metadata = element.metadata.remove_record_key() - return True - - -@dataclasses.dataclass -class CustomPackAndBatchOperation(Generic[_T]): - """PyGrain pack-and-batch operation - see module docstring. 
- - WARNING: This class is deprecated. Please use - lazy_dataset.FirstFitPackIterDataset instead. - - Attributes: - batch_size: int, the batch size. - length_struct: A pytree, with the same structure as `input_iterator` - elements, but where leaves are ints, representing the packed length of the - corresponding feature. - max_segments: int, max segments per sequence - - __call__() takes an input iterator, where elements are `Record`s containing: - - input_data: Pytrees of arrays. For more info about PyTrees, please refer to: - https://jax.readthedocs.io/en/latest/pytrees.html. Packed leaves should be - n-dimensional arrays, with sequence length as the leading dimension, i.e. - shape (T_in, ...), where T_in < T_packed. Note that leaves can and will - often have ragged length dimensions across different elements of the input - iterator. - - The output of __call__() will be an iterator over `Record`s containing a - 3-tuple of Pytrees. These are: - - data: The batched and packed data. This is a Pytree with parallel structure - to elements of `input_iterator`. Leaves have shape (B, T_packed, ...). - segmentations: Pytree with the same structure as `data`, and leaves of shape - (B, T). Represents which example each entry comes from. This may be used - for Transformer attention masks, for example. - positions: Pytree with the same structure as `data`, and leaves of shape - (B, T). Represents the position of each entry within their original - example. This may be used e.g. in Transformer absolute position - embeddings. - """ - - length_struct: Any # PyTree[int] - batch_size: int - max_segments: int - # We don't know input shapes and corresponding buffer shapes until __call__. - _cur_batch: Union[_PackedBatch, None] = None - - def __post_init__(self): - logging.error( - "PackAndBatchOperation is deprecated. Please use" " lazy_dataset.FirstFitPackIterDataset instead." 
- ) - - def __call__( - self, input_iterator: Iterator[record.Record[_T]] - ) -> Iterator[record.Record[tuple[_T, _T, _T]]]: - for element in input_iterator: - # Use `element` to set dtypes + trailing dimensions. - if self._cur_batch is None: # pytype: disable=attribute-error - self._cur_batch = _PackedBatch( - element.data, self.batch_size, self.length_struct, self.max_segments - ) - - # Try adding element to the current packed batch. - element_added_to_batch = self._cur_batch.try_add_to_batch(element) - - # When we have a full batch, yield the current packed data, - # and then start a new batch with this element. - if not element_added_to_batch: - yield self._cur_batch.get_packed_batch() # Main yield - self._cur_batch = _PackedBatch( - element.data, self.batch_size, self.length_struct, self.max_segments - ) - self._cur_batch.try_add_to_batch(element) - - # Final batch - yield self._cur_batch.get_packed_batch() diff --git a/primus/backends/maxtext/layers/attention_op.py b/primus/backends/maxtext/layers/attention_op.py index 5cdd9aee6..c10a0a2ad 100644 --- a/primus/backends/maxtext/layers/attention_op.py +++ b/primus/backends/maxtext/layers/attention_op.py @@ -6,7 +6,13 @@ ############################################################################### import jax.numpy as jnp -from MaxText.common_types import MODEL_MODE_TRAIN, Array, AttentionType +from MaxText.common_types import ( + DEFAULT_MASK_VALUE, + MODEL_MODE_TRAIN, + Array, + AttentionType, +) +from MaxText.layers import nnx_wrappers from MaxText.layers.attention_op import AttentionOp @@ -23,11 +29,14 @@ def cudnn_flash_attention( model_mode: str = MODEL_MODE_TRAIN, ) -> Array: """CUDNN Flash Attention with Transformer Engine. - 1. Stable API, supports GQA, SWA (only with causal masking) - 2. Head_dim = 256 is also supported from TE-1.12 stable release with CUDNN 12.6 + 1. Stable API, supports MHA, GQA, SWA, Packing and Context Parallelism + 2. 
Context Parallelism currently only supports causal masking and no packing """ # These imports are only meant to work in a GPU build. # pylint: disable=import-outside-toplevel + from transformer_engine.jax.attention import ( + SequenceDescriptor, # pytype: disable=import-error + ) from transformer_engine.jax.flax.transformer import ( DotProductAttention, # pytype: disable=import-error ) @@ -36,34 +45,47 @@ def cudnn_flash_attention( using_context_parallelism = self.mesh.shape["context"] > 1 - if self.attention_type == AttentionType.LOCAL_SLIDING and using_context_parallelism: - raise AssertionError( - "Sliding window attention is not supported when context parallelism is enabled" - ) - + # Initialize default attention configuration sliding_window_size = None mask_type = "padding_causal" - qkv_layout = "BSHD_BSHD_BSHD" # 'BS3HD', 'BSHD_BS2HD' or 'BSHD_BSHD_BSHD' + qkv_layout = "BSHD_BSHD_BSHD" # Non-packed format: 'BS3HD', 'BSHD_BS2HD' or 'BSHD_BSHD_BSHD' max_segments_per_seq = 1 # max number of segments per sequence; for non-packed its 1 + attn_mask_threshold = 0.5 + # Handle local sliding window attention if configured if self.attention_type == AttentionType.LOCAL_SLIDING: sliding_window_size = [self.sliding_window_size, 0] + # Handle packing configurations if self.config.packing and self.config.dataset_type != "synthetic": + qkv_layout = "THD_THD_THD" # Packed format: 'T3HD', 'THD_T2HD' or 'THD_THD_THD' if decoder_segment_ids is None: decoder_segment_ids = jnp.ones(shape=query.shape[:2], dtype=jnp.int32) attn_mask = SequenceDescriptor.from_segment_ids_and_pos( segment_ids=decoder_segment_ids, segment_pos=None ) - qkv_layout = "THD_THD_THD" # 'T3HD', 'THD_T2HD' or 'THD_THD_THD' - max_segments_per_seq = 32 - elif ( - using_context_parallelism or self.config.dataset_type == "synthetic" - ): # context parallelism currently only supports causal masking and no packing + # Create dummy SequenceDescriptor for lazy_init + dummy_segment_ids = jnp.ones(shape=query.shape[:2], 
dtype=jnp.int32) + dummy_attn_mask = SequenceDescriptor.from_segment_ids_and_pos( + segment_ids=dummy_segment_ids, segment_pos=None + ) + max_segments_per_seq = self.config.max_segments_per_seq + elif using_context_parallelism or self.config.dataset_type == "synthetic": + if self.attention_type == AttentionType.LOCAL_SLIDING: + raise AssertionError("Sliding window attention is not supported for context parallelism") + # Context parallelism without packing: only supports causal masking attn_mask = None + dummy_attn_mask = None mask_type = "causal" else: + # Default case: no packing, no context parallelism + dummy_attn_mask = jnp.zeros( + (1, 1, 1, self.max_target_length, self.max_target_length), dtype=jnp.uint8 + ) attn_mask = self.generate_attention_mask(query, key, decoder_segment_ids, model_mode) + attn_mask = jnp.where((attn_mask >= DEFAULT_MASK_VALUE * attn_mask_threshold), 0, 1).astype( + jnp.uint8 + ) dpa_layer = DotProductAttention( head_dim=head_dim, @@ -76,11 +98,27 @@ def cudnn_flash_attention( dtype=self.dtype, float32_logits=self.float32_logits, qkv_layout=qkv_layout, - # scale_factor=1.0, + scale_factor=1.0, transpose_batch_sequence=False, window_size=sliding_window_size, context_parallel_causal_load_balanced=self.config.context_parallel_load_balance, context_parallel_axis="context", + # context_parallel_strategy=self.config.context_parallel_strategy, max_segments_per_seq=max_segments_per_seq, ) - return dpa_layer(query, key, value, mask=attn_mask) + + dpa_layer = nnx_wrappers.ToNNX(dpa_layer, rngs=self.rngs) + dummy_query_prefill = jnp.zeros( + (1, self.max_target_length, self.num_query_heads, self.config.head_dim), dtype=self.dtype + ) + dummy_key_prefill = jnp.zeros( + (1, self.max_target_length, self.num_kv_heads, self.config.head_dim), dtype=self.dtype + ) + dummy_value_prefill = jnp.zeros( + (1, self.max_target_length, self.num_kv_heads, self.config.head_dim), dtype=self.dtype + ) + + dpa_layer.lazy_init( + dummy_query_prefill, dummy_key_prefill, 
dummy_value_prefill, sequence_descriptor=dummy_attn_mask + ) + return dpa_layer(query, key, value, sequence_descriptor=attn_mask) diff --git a/primus/backends/maxtext/layers/attentions.py b/primus/backends/maxtext/layers/attentions.py deleted file mode 100644 index ceb880d84..000000000 --- a/primus/backends/maxtext/layers/attentions.py +++ /dev/null @@ -1,49 +0,0 @@ -############################################################################### -# Copyright 2023–2025 Google LLC. All rights reserved. -# Modification Copyright© 2025 Advanced Micro Devices, Inc. All rights reserved. -# -# See LICENSE for license information. -############################################################################### - -from typing import Tuple - -from flax import nnx -from MaxText.layers.attentions import Attention -from MaxText.layers.linears import DenseGeneral - - -class PrimusAttention(Attention): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - def init_query_w(self, inputs_q_shape: Tuple) -> nnx.Module: - """Query projection initialization.""" - - # NOTE: T5 does not explicitly rescale the attention logits by - # 1/sqrt(depth_kq)! This is folded into the initializers of the - # linear transformations, which is equivalent under Adafactor. 
- # depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype) - depth_scaling = 1.0 - - def query_init(*args): - # pylint: disable=no-value-for-parameter - return self.kernel_init(*args) / depth_scaling - - kernel_axes = ( - (None, None, None) - if self.config.ici_context_autoregressive_parallelism > 1 - else ("embed", "q_heads", "kv") - ) - return DenseGeneral( - in_features_shape=self.convert_dense_general_inputs_shape(inputs_q_shape), - out_features_shape=(self.num_query_heads, self.head_dim), - axis=-1, - kernel_init=query_init, - kernel_axes=kernel_axes, - dtype=self.dtype, - weight_dtype=self.weight_dtype, - quant=self.quant, - matmul_precision=self.config.matmul_precision, - use_bias=self.use_bias_in_projections, - rngs=self.rngs, - ) diff --git a/primus/backends/maxtext/layers/gemma.py b/primus/backends/maxtext/layers/gemma.py new file mode 100644 index 000000000..400d216e1 --- /dev/null +++ b/primus/backends/maxtext/layers/gemma.py @@ -0,0 +1,98 @@ +############################################################################### +# Copyright 2023–2025 Google LLC. All rights reserved. +# Modification Copyright© 2025 Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. 
+############################################################################### + +from typing import Optional + +from flax import nnx +from jax.sharding import Mesh +from MaxText import max_utils +from MaxText.common_types import Config +from MaxText.layers import quantizations +from MaxText.layers.attentions import Attention +from MaxText.layers.gemma import GemmaDecoderLayer +from MaxText.layers.linears import Dropout, MlpBlock +from MaxText.layers.normalizations import RMSNorm +from MaxText.layers.quantizations import AqtQuantization as Quant + + +class PrimusGemmaDecoderLayer(GemmaDecoderLayer): + def __init__( + self, + config: Config, + mesh: Mesh, + model_mode: str, + quant: Optional[Quant] = None, + *, + rngs: nnx.Rngs, + ): + self.config = config + self.mesh = mesh + self.model_mode = model_mode + self.quant = quant + self.rngs = rngs + + batch_size, seq_len = max_utils.get_batch_seq_len_for_mode(config, model_mode) + dummy_inputs_shape = (batch_size, seq_len, config.emb_dim) + + self.pre_self_attention_norm = RMSNorm( + num_features=config.emb_dim, + dtype=config.dtype, + weight_dtype=config.weight_dtype, + kernel_axes=("norm",), + rngs=self.rngs, + ) + + self.self_attention = Attention( + config=config, + num_query_heads=config.num_query_heads, + num_kv_heads=config.num_kv_heads, + head_dim=config.head_dim, + max_target_length=config.max_target_length, + max_prefill_predict_length=config.max_prefill_predict_length, + attention_kernel=config.attention, + inputs_q_shape=dummy_inputs_shape, + inputs_kv_shape=dummy_inputs_shape, + mesh=self.mesh, + dtype=config.dtype, + weight_dtype=config.weight_dtype, + dropout_rate=config.dropout_rate, + float32_qk_product=config.float32_qk_product, + float32_logits=config.float32_logits, + quant=self.quant, + kv_quant=quantizations.configure_kv_quant(config), + use_ragged_attention=config.use_ragged_attention, + ragged_block_size=config.ragged_block_size, + query_pre_attn_scalar=(config.head_dim**-0.5), + 
model_mode=self.model_mode, + rngs=self.rngs, + ) + + self.pre_ffw_norm = RMSNorm( + num_features=config.emb_dim, + dtype=config.dtype, + weight_dtype=config.weight_dtype, + kernel_axes=("norm",), + rngs=self.rngs, + ) + + self.mlp = MlpBlock( + config=config, + mesh=self.mesh, + in_features=config.emb_dim, + intermediate_dim=config.mlp_dim, + activations=config.mlp_activations, + intermediate_dropout_rate=config.dropout_rate, + dtype=config.dtype, + weight_dtype=config.weight_dtype, + quant=self.quant, + model_mode=self.model_mode, + rngs=self.rngs, + ) + + self.dropout = Dropout(rate=config.dropout_rate, broadcast_dims=(-2,), rngs=self.rngs) + + self.activation_axis_names = ("activation_batch", "activation_norm_length", "activation_embed") diff --git a/primus/backends/maxtext/layers/gemma2.py b/primus/backends/maxtext/layers/gemma2.py new file mode 100644 index 000000000..72c0fb1b9 --- /dev/null +++ b/primus/backends/maxtext/layers/gemma2.py @@ -0,0 +1,197 @@ +############################################################################### +# Copyright 2023–2025 Google LLC. All rights reserved. +# Modification Copyright© 2025 Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. 
###############################################################################

from typing import Optional

from flax import nnx
from jax.sharding import Mesh
from MaxText import max_utils
from MaxText.common_types import MODEL_MODE_PREFILL, Config
from MaxText.layers import quantizations
from MaxText.layers.attentions import Attention, AttentionType
from MaxText.layers.gemma2 import Gemma2DecoderLayer
from MaxText.layers.linears import Dropout, MlpBlock
from MaxText.layers.normalizations import RMSNorm
from MaxText.layers.quantizations import AqtQuantization as Quant


class PrimusGemma2DecoderLayer(Gemma2DecoderLayer):
    """Gemma2 decoder layer with eagerly (nnx-style) constructed submodules.

    Overrides the parent ``__init__`` to build every submodule explicitly:
    a *local* sliding-window attention + MLP pair and a *global* attention +
    MLP pair, each with its own pre/post RMSNorm layers. The attribute names
    (``self_attention_local``, ``mlp_global``, ...) are presumably the ones
    consumed by the inherited ``__call__`` — confirm against
    ``MaxText.layers.gemma2.Gemma2DecoderLayer``.
    """

    def __init__(
        self,
        config: Config,
        mesh: Mesh,
        model_mode: str,
        quant: Optional[Quant] = None,
        *,
        rngs: nnx.Rngs,
    ):
        """Build all decoder-layer submodules from ``config``.

        Args:
          config: MaxText model configuration (dims, dtypes, dropout, etc.).
          mesh: JAX device mesh used for sharding the attention/MLP blocks.
          model_mode: execution mode; MODEL_MODE_PREFILL selects prefill-
            specific activation sharding axis names below.
          quant: optional AQT quantization config applied to attention/MLP.
          rngs: nnx RNG container threaded into every submodule.
        """
        self.config = config
        self.mesh = mesh
        self.model_mode = model_mode
        self.quant = quant
        self.rngs = rngs

        # Dummy activation shape (batch, seq, emb) used only so Attention can
        # pre-compute its input shapes for the given mode.
        batch_size, seq_len = max_utils.get_batch_seq_len_for_mode(config, model_mode)
        dummy_inputs_shape = (batch_size, seq_len, config.emb_dim)

        # --- local (sliding-window) branch -------------------------------
        # NOTE(review): RMSNorm here is not given normalization_layer_epsilon
        # (unlike the llama2/mistral variants) — confirm the default epsilon
        # is intended for Gemma2.
        self.pre_self_attention_norm_local = RMSNorm(
            num_features=config.emb_dim,
            dtype=config.dtype,
            weight_dtype=config.weight_dtype,
            kernel_axes=("norm",),
            rngs=self.rngs,
        )

        self.self_attention_local = Attention(
            config=config,
            num_query_heads=config.num_query_heads,
            num_kv_heads=config.num_kv_heads,
            head_dim=config.head_dim,
            max_target_length=config.max_target_length,
            max_prefill_predict_length=config.max_prefill_predict_length,
            attention_kernel=config.attention,
            inputs_q_shape=dummy_inputs_shape,
            inputs_kv_shape=dummy_inputs_shape,
            mesh=self.mesh,
            dtype=config.dtype,
            weight_dtype=config.weight_dtype,
            dropout_rate=config.dropout_rate,
            float32_qk_product=config.float32_qk_product,
            float32_logits=config.float32_logits,
            quant=self.quant,
            kv_quant=quantizations.configure_kv_quant(config),
            attention_type=AttentionType.LOCAL_SLIDING,
            sliding_window_size=config.sliding_window_size,
            attn_logits_soft_cap=config.attn_logits_soft_cap,
            # Gemma2 scales queries by head_dim**-0.5 before attention.
            query_pre_attn_scalar=(config.head_dim**-0.5),
            model_mode=self.model_mode,
            rngs=self.rngs,
        )

        if config.use_post_attn_norm:
            self.post_self_attention_norm_local = RMSNorm(
                num_features=config.emb_dim,
                dtype=config.dtype,
                weight_dtype=config.weight_dtype,
                kernel_axes=("norm",),
                rngs=self.rngs,
            )

        self.pre_ffw_norm_local = RMSNorm(
            num_features=config.emb_dim,
            dtype=config.dtype,
            weight_dtype=config.weight_dtype,
            kernel_axes=("norm",),
            rngs=self.rngs,
        )

        self.mlp_local = MlpBlock(
            config=config,
            mesh=self.mesh,
            in_features=config.emb_dim,
            intermediate_dim=config.mlp_dim,
            activations=config.mlp_activations,
            intermediate_dropout_rate=config.dropout_rate,
            dtype=config.dtype,
            weight_dtype=config.weight_dtype,
            quant=self.quant,
            model_mode=self.model_mode,
            rngs=self.rngs,
        )

        if config.use_post_ffw_norm:
            self.post_ffw_norm_local = RMSNorm(
                num_features=config.emb_dim,
                dtype=config.dtype,
                weight_dtype=config.weight_dtype,
                kernel_axes=("norm",),
                rngs=self.rngs,
            )

        # Shared residual-path dropout; broadcast over the sequence axis.
        self.dropout = Dropout(rate=config.dropout_rate, broadcast_dims=(-2,), rngs=self.rngs)

        # --- global (full-context) branch --------------------------------
        self.pre_self_attention_norm_global = RMSNorm(
            num_features=config.emb_dim,
            dtype=config.dtype,
            weight_dtype=config.weight_dtype,
            kernel_axes=("norm",),
            rngs=self.rngs,
        )

        self.self_attention_global = Attention(
            config=config,
            num_query_heads=config.num_query_heads,
            num_kv_heads=config.num_kv_heads,
            head_dim=config.head_dim,
            max_target_length=config.max_target_length,
            max_prefill_predict_length=config.max_prefill_predict_length,
            attention_kernel=config.attention,
            inputs_q_shape=dummy_inputs_shape,
            inputs_kv_shape=dummy_inputs_shape,
            mesh=self.mesh,
            dtype=config.dtype,
            weight_dtype=config.weight_dtype,
            dropout_rate=config.dropout_rate,
            # NOTE(review): global attention hard-codes float32 QK product and
            # logits, while the local branch honors the config flags above —
            # confirm this asymmetry is intentional.
            float32_qk_product=True,
            float32_logits=True,
            quant=self.quant,
            kv_quant=quantizations.configure_kv_quant(config),
            attention_type=AttentionType.GLOBAL,
            attn_logits_soft_cap=config.attn_logits_soft_cap,
            query_pre_attn_scalar=(config.head_dim**-0.5),
            model_mode=self.model_mode,
            rngs=self.rngs,
        )

        if config.use_post_attn_norm:
            self.post_self_attention_norm_global = RMSNorm(
                num_features=config.emb_dim,
                dtype=config.dtype,
                weight_dtype=config.weight_dtype,
                kernel_axes=("norm",),
                rngs=self.rngs,
            )

        self.pre_ffw_norm_global = RMSNorm(
            num_features=config.emb_dim,
            dtype=config.dtype,
            weight_dtype=config.weight_dtype,
            kernel_axes=("norm",),
            rngs=self.rngs,
        )

        self.mlp_global = MlpBlock(
            config=config,
            mesh=self.mesh,
            in_features=config.emb_dim,
            intermediate_dim=config.mlp_dim,
            activations=config.mlp_activations,
            intermediate_dropout_rate=config.dropout_rate,
            dtype=config.dtype,
            weight_dtype=config.weight_dtype,
            quant=self.quant,
            model_mode=self.model_mode,
            rngs=self.rngs,
        )

        if config.use_post_ffw_norm:
            self.post_ffw_norm_global = RMSNorm(
                num_features=config.emb_dim,
                dtype=config.dtype,
                weight_dtype=config.weight_dtype,
                kernel_axes=("norm",),
                rngs=self.rngs,
            )

        # Prefill uses a dedicated sequence-axis name for activation sharding.
        if model_mode == MODEL_MODE_PREFILL:
            self.activation_axis_names = (
                "activation_batch",
                "prefill_activation_norm_length",
                "activation_embed",
            )
        else:
            self.activation_axis_names = ("activation_batch", "activation_norm_length", "activation_embed")
###############################################################################

import functools

from flax import nnx
from jax.sharding import Mesh
from MaxText import max_utils
from MaxText.common_types import MODEL_MODE_PREFILL, Config
from MaxText.layers import quantizations
from MaxText.layers.attentions import Attention
from MaxText.layers.linears import Dropout, MlpBlock
from MaxText.layers.llama2 import LlamaDecoderLayer
from MaxText.layers.normalizations import RMSNorm
from MaxText.layers.quantizations import AqtQuantization as Quant
from MaxText.sharding import maybe_shard_with_logical


class PrimusLlamaDecoderLayer(LlamaDecoderLayer):
    """Llama decoder layer with eagerly (nnx-style) constructed submodules.

    Overrides the parent ``__init__`` to build pre/post attention RMSNorms,
    the self-attention block, the MLP block, and residual dropout directly
    from ``config``. Attribute names are presumably those consumed by the
    inherited ``__call__`` — confirm against ``MaxText.layers.llama2``.
    """

    def __init__(
        self,
        config: Config,
        model_mode: str,
        mesh: Mesh,
        rngs: nnx.Rngs,
        quant: None | Quant = None,
    ):
        """Build all decoder-layer submodules from ``config``.

        Args:
          config: MaxText model configuration.
          model_mode: execution mode; MODEL_MODE_PREFILL selects prefill-
            specific activation sharding axis names.
          mesh: JAX device mesh for sharded submodules.
          rngs: nnx RNG container threaded into every submodule.
          quant: optional AQT quantization config.
        """

        self.config = config
        self.mesh = mesh
        self.quant = quant

        # Prefill uses a dedicated sequence-axis name for activation sharding.
        if model_mode == MODEL_MODE_PREFILL:
            self.activation_axis_names = (
                "activation_batch",
                "prefill_activation_norm_length",
                "activation_embed",
            )
        else:
            self.activation_axis_names = ("activation_batch", "activation_norm_length", "activation_embed")

        # Dummy activation shape (batch, seq, emb) so Attention can size its
        # inputs for the given mode.
        batch_size, seq_len = max_utils.get_batch_seq_len_for_mode(config, model_mode)
        dummy_inputs_shape = (batch_size, seq_len, config.emb_dim)

        self.pre_self_attention_layer_norm = RMSNorm(
            num_features=config.emb_dim,
            dtype=config.dtype,
            weight_dtype=config.weight_dtype,
            shard_mode=config.shard_mode,
            kernel_axes=("norm",),
            epsilon=config.normalization_layer_epsilon,
            rngs=rngs,
        )

        self.self_attention = Attention(
            config=config,
            num_query_heads=config.num_query_heads,
            num_kv_heads=config.num_kv_heads,
            head_dim=config.head_dim,
            max_target_length=config.max_target_length,
            max_prefill_predict_length=config.max_prefill_predict_length,
            attention_kernel=config.attention,
            inputs_q_shape=dummy_inputs_shape,
            inputs_kv_shape=dummy_inputs_shape,
            mesh=mesh,
            dtype=config.dtype,
            weight_dtype=config.weight_dtype,
            dropout_rate=config.dropout_rate,
            float32_qk_product=config.float32_qk_product,
            float32_logits=config.float32_logits,
            quant=self.quant,
            kv_quant=quantizations.configure_kv_quant(config),
            # Axis orders arrive as comma-separated strings in config and are
            # converted to int tuples here.
            prefill_cache_axis_order=tuple(map(int, config.prefill_cache_axis_order.split(","))),
            ar_cache_axis_order=tuple(map(int, config.ar_cache_axis_order.split(","))),
            compute_axis_order=tuple(map(int, config.compute_axis_order.split(","))),
            reshape_q=config.reshape_q,
            use_ragged_attention=config.use_ragged_attention,
            ragged_block_size=config.ragged_block_size,
            query_pre_attn_scalar=(config.head_dim**-0.5),
            model_mode=model_mode,
            rngs=rngs,
        )

        self.post_self_attention_layer_norm = RMSNorm(
            num_features=config.emb_dim,
            dtype=config.dtype,
            weight_dtype=config.weight_dtype,
            shard_mode=config.shard_mode,
            kernel_axes=("norm",),
            epsilon=config.normalization_layer_epsilon,
            rngs=rngs,
        )

        self.mlp = MlpBlock(
            in_features=config.emb_dim,
            intermediate_dim=config.mlp_dim,
            activations=config.mlp_activations,
            intermediate_dropout_rate=config.dropout_rate,
            dtype=config.dtype,
            weight_dtype=config.weight_dtype,
            config=config,
            mesh=mesh,
            quant=self.quant,
            model_mode=model_mode,
            rngs=rngs,
        )

        # Residual-path dropout; broadcast over the sequence axis.
        self.dropout = Dropout(rate=config.dropout_rate, broadcast_dims=(-2,), rngs=rngs)

        # Pre-bind mesh/shard_mode so the forward pass can shard activations
        # with a single positional call.
        self._maybe_shard_with_logical = functools.partial(
            maybe_shard_with_logical,
            mesh=self.mesh,
            shard_mode=config.shard_mode,
        )
###############################################################################

from flax import nnx
from jax.sharding import Mesh
from MaxText import max_utils
from MaxText.common_types import Config
from MaxText.layers import quantizations
from MaxText.layers.attentions import Attention
from MaxText.layers.linears import Dropout, MlpBlock
from MaxText.layers.mistral import MistralDecoderLayer
from MaxText.layers.normalizations import RMSNorm
from MaxText.layers.quantizations import AqtQuantization as Quant


class PrimusMistralDecoderLayer(MistralDecoderLayer):
    """Mistral decoder layer with eagerly (nnx-style) constructed submodules.

    Overrides the parent ``__init__`` to build pre/post attention RMSNorms,
    the self-attention block, the MLP block, and residual dropout directly
    from ``config``. Attribute names are presumably those consumed by the
    inherited ``__call__`` — confirm against ``MaxText.layers.mistral``.

    NOTE(review): unlike the Gemma2/Llama variants, ``activation_axis_names``
    has no MODEL_MODE_PREFILL branch here — confirm prefill sharding is not
    needed for this layer.
    """

    def __init__(
        self,
        config: Config,
        model_mode: str,
        mesh: Mesh,
        *,
        rngs: nnx.Rngs,
        quant: None | Quant = None,
    ):
        """Build all decoder-layer submodules from ``config``.

        Args:
          config: MaxText model configuration.
          model_mode: execution mode, forwarded to Attention/MlpBlock.
          mesh: JAX device mesh for sharded submodules.
          rngs: nnx RNG container threaded into every submodule.
          quant: optional AQT quantization config.
        """
        self.config = config
        self.mesh = mesh
        self.quant = quant
        self.rngs = rngs

        # Dummy activation shape (batch, seq, emb) so Attention can size its
        # inputs for the given mode.
        batch_size, seq_len = max_utils.get_batch_seq_len_for_mode(config, model_mode)
        dummy_inputs_shape = (batch_size, seq_len, config.emb_dim)

        self.pre_self_attention_layer_norm = RMSNorm(
            num_features=config.emb_dim,
            dtype=config.dtype,
            weight_dtype=config.weight_dtype,
            kernel_axes=("norm",),
            epsilon=config.normalization_layer_epsilon,
            rngs=self.rngs,
        )

        self.self_attention = Attention(
            config=config,
            num_query_heads=config.num_query_heads,
            num_kv_heads=config.num_kv_heads,
            head_dim=config.head_dim,
            max_target_length=config.max_target_length,
            max_prefill_predict_length=config.max_prefill_predict_length,
            attention_kernel=config.attention,
            inputs_q_shape=dummy_inputs_shape,
            inputs_kv_shape=dummy_inputs_shape,
            mesh=mesh,
            dtype=config.dtype,
            weight_dtype=config.weight_dtype,
            dropout_rate=config.dropout_rate,
            float32_qk_product=config.float32_qk_product,
            float32_logits=config.float32_logits,
            quant=self.quant,
            kv_quant=quantizations.configure_kv_quant(config),
            # Axis orders arrive as comma-separated strings in config and are
            # converted to int tuples here.
            prefill_cache_axis_order=tuple(map(int, config.prefill_cache_axis_order.split(","))),
            ar_cache_axis_order=tuple(map(int, config.ar_cache_axis_order.split(","))),
            compute_axis_order=tuple(map(int, config.compute_axis_order.split(","))),
            reshape_q=config.reshape_q,
            use_ragged_attention=config.use_ragged_attention,
            ragged_block_size=config.ragged_block_size,
            query_pre_attn_scalar=(config.head_dim**-0.5),
            model_mode=model_mode,
            rngs=self.rngs,
        )

        self.post_self_attention_layer_norm = RMSNorm(
            num_features=config.emb_dim,
            dtype=config.dtype,
            weight_dtype=config.weight_dtype,
            kernel_axes=("norm",),
            epsilon=config.normalization_layer_epsilon,
            rngs=self.rngs,
        )

        self.mlp = MlpBlock(
            mesh=self.mesh,
            in_features=config.emb_dim,
            intermediate_dim=config.mlp_dim,
            activations=config.mlp_activations,
            intermediate_dropout_rate=config.dropout_rate,
            dtype=config.dtype,
            weight_dtype=config.weight_dtype,
            config=config,
            quant=self.quant,
            model_mode=model_mode,
            rngs=self.rngs,
        )

        # Residual-path dropout; broadcast over the sequence axis.
        self.dropout = Dropout(rate=config.dropout_rate, broadcast_dims=(-2,), rngs=self.rngs)

        self.activation_axis_names = ("activation_batch", "activation_norm_length", "activation_embed")
###############################################################################

from flax import linen as nn
from flax import nnx
from jax.sharding import Mesh
from MaxText import max_utils
from MaxText.common_types import Config
from MaxText.layers import initializers, moe, quantizations
from MaxText.layers.attentions import Attention
from MaxText.layers.linears import Dropout
from MaxText.layers.mixtral import MixtralDecoderLayer
from MaxText.layers.normalizations import RMSNorm
from MaxText.layers.quantizations import AqtQuantization as Quant


class PrimusMixtralDecoderLayer(MixtralDecoderLayer):
    """Mixtral (MoE) decoder layer with eagerly (nnx-style) constructed submodules.

    Overrides the parent ``__init__`` to build pre/post attention RMSNorms,
    the self-attention block, the routed-MoE block, and residual dropout
    directly from ``config``. Attribute names (including ``MoeBlock_0``) are
    presumably those consumed by the inherited ``__call__`` — confirm against
    ``MaxText.layers.mixtral``.
    """

    # FIX(review): removed the stray ``@nn.compact`` decorator that was applied
    # to ``__init__``. ``nn.compact`` is a flax *linen* decorator for methods of
    # ``linen.Module`` (normally ``__call__``); this class follows the nnx
    # eager-construction style like its sibling decoder layers (gemma2, llama2,
    # mistral), none of which carry the decorator.
    def __init__(
        self,
        config: Config,
        mesh: Mesh,
        model_mode: str,
        quant: None | Quant = None,
        *,
        rngs: nnx.Rngs,
    ):
        """Build all decoder-layer submodules from ``config``.

        Args:
          config: MaxText model configuration (incl. MoE expert counts).
          mesh: JAX device mesh for sharded submodules.
          model_mode: execution mode, forwarded to Attention.
          quant: optional AQT quantization config.
          rngs: nnx RNG container threaded into every submodule.
        """
        self.config = config
        self.mesh = mesh
        self.model_mode = model_mode
        self.quant = quant
        self.rngs = rngs

        # Dummy activation shape (batch, seq, emb) so Attention can size its
        # inputs for the given mode.
        batch_size, seq_len = max_utils.get_batch_seq_len_for_mode(config, model_mode)
        dummy_inputs_shape = (batch_size, seq_len, config.emb_dim)

        self.pre_self_attention_layer_norm = RMSNorm(
            num_features=config.emb_dim,
            dtype=config.dtype,
            weight_dtype=config.weight_dtype,
            kernel_axes=("norm",),
            epsilon=config.normalization_layer_epsilon,
            rngs=self.rngs,
        )

        self.self_attention = Attention(
            config=config,
            num_query_heads=config.num_query_heads,
            num_kv_heads=config.num_kv_heads,
            head_dim=config.head_dim,
            max_target_length=config.max_target_length,
            max_prefill_predict_length=config.max_prefill_predict_length,
            attention_kernel=config.attention,
            inputs_q_shape=dummy_inputs_shape,
            inputs_kv_shape=dummy_inputs_shape,
            mesh=mesh,
            dtype=config.dtype,
            weight_dtype=config.weight_dtype,
            dropout_rate=config.dropout_rate,
            float32_qk_product=config.float32_qk_product,
            float32_logits=config.float32_logits,
            quant=self.quant,
            kv_quant=quantizations.configure_kv_quant(config),
            # Axis orders arrive as comma-separated strings in config and are
            # converted to int tuples here.
            prefill_cache_axis_order=tuple(map(int, config.prefill_cache_axis_order.split(","))),
            ar_cache_axis_order=tuple(map(int, config.ar_cache_axis_order.split(","))),
            compute_axis_order=tuple(map(int, config.compute_axis_order.split(","))),
            reshape_q=config.reshape_q,
            use_ragged_attention=config.use_ragged_attention,
            ragged_block_size=config.ragged_block_size,
            query_pre_attn_scalar=(config.head_dim**-0.5),
            model_mode=model_mode,
            rngs=self.rngs,
        )

        self.post_self_attention_layer_norm = RMSNorm(
            num_features=config.emb_dim,
            dtype=config.dtype,
            weight_dtype=config.weight_dtype,
            kernel_axes=("norm",),
            epsilon=config.normalization_layer_epsilon,
            rngs=self.rngs,
        )

        # Routed mixture-of-experts block replacing the dense MLP.
        self.MoeBlock_0 = moe.RoutedMoE(
            config=config,
            num_experts=config.num_experts,
            num_experts_per_tok=config.num_experts_per_tok,
            mesh=mesh,
            kernel_init=initializers.nd_dense_init(1.0, "fan_in", "truncated_normal"),
            kernel_axes=("embed", None),
            intermediate_dim=config.mlp_dim,
            dtype=config.dtype,
            weight_dtype=config.weight_dtype,
            quant=self.quant,
            rngs=self.rngs,
        )

        # Residual-path dropout; broadcast over the sequence axis.
        # (Uses self.rngs for consistency with the sibling decoder layers;
        # identical to the previously-passed local ``rngs``.)
        self.dropout = Dropout(rate=config.dropout_rate, broadcast_dims=(-2,), rngs=self.rngs)

        # NOTE(review): no MODEL_MODE_PREFILL branch here, unlike gemma2/llama2
        # — confirm prefill sharding axis names are not needed for Mixtral.
        self.activation_axis_names = ("activation_batch", "activation_norm_length", "activation_embed")
return super().dense_matmul( + inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kernel, wo_kernel, w0_bias, w1_bias, wo_bias + ) def sparse_matmul( self, @@ -92,6 +97,9 @@ def sparse_matmul( w0_kernel, w1_kernel, wo_kernel, + w0_bias, + w1_bias, + wo_bias, ): """Perform sparse matrix multiplication with optional Primus Turbo backend.""" if not self.config.use_turbo_grouped_gemm: @@ -109,7 +117,15 @@ def sparse_matmul( # Fallback to original implementation if primus_turbo is not available max_logging.log("WARNING: primus_turbo not available, using default ragged_dot in MoE") return super().sparse_matmul( - inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kernel, wo_kernel + inputs, + gate_logits, + pre_bias_logits, + w0_kernel, + w1_kernel, + wo_kernel, + w0_bias, + w1_bias, + wo_bias, ) max_logging.log("Using primus_turbo grouped_gemm in MoE") @@ -129,7 +145,15 @@ def _turbo_ragged_dot(*, lhs, rhs, group_sizes, preferred_element_type=None, **k jax.lax.ragged_dot = _turbo_ragged_dot try: return super().sparse_matmul( - inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kernel, wo_kernel + inputs, + gate_logits, + pre_bias_logits, + w0_kernel, + w1_kernel, + wo_kernel, + w0_bias, + w1_bias, + wo_bias, ) finally: jax.lax.ragged_dot = _orig_ragged_dot diff --git a/primus/backends/maxtext/max_utils.py b/primus/backends/maxtext/max_utils.py index 84ed82253..e5a048ed5 100644 --- a/primus/backends/maxtext/max_utils.py +++ b/primus/backends/maxtext/max_utils.py @@ -10,7 +10,142 @@ import socket import jax +import orbax.checkpoint as ocp from MaxText import max_logging +from MaxText.max_utils import ( + _retrieve_jax_init_info, + get_coordinator_ip_address, + is_cpu_backend, + is_gpu_backend, +) +from orbax.checkpoint.experimental.emergency.multi_tier_checkpointing import ( + initialization, +) + + +def maybe_initialize_jax_distributed_system(raw_keys): + """The best recipe to initialize the Jax Distributed System has varied over time. 
We keep a layer of + indirection in MaxText to avoid breaking the call sites unnecessarily. + + Currently jax.distributed.initialize() fully works as expected! + + For CPUs, we call jax.distributed.initialize() explicitly, with the specified arguments. + """ + if raw_keys["skip_jax_distributed_system"]: + max_logging.log("Skipping jax distributed system due to skip_jax_distributed_system=True flag.") + return + if raw_keys["enable_single_controller"]: + max_logging.log("Skipping jax distributed system since its not needed for single controller.") + return + if jax.distributed.is_initialized(): + max_logging.log("Jax distributed system is already initialized.") + return + if raw_keys["inference_benchmark_test"]: + # Disable initialization for inference benmark test. + return + if raw_keys["compile_topology"]: + # Don't initialize jax distributed with AOT compilation + return + if is_gpu_backend(raw_keys): + max_logging.log("Attempting to initialize the jax distributed system for GPU backend...") + initialize_jax_for_gpu(raw_keys) + max_logging.log("Jax distributed system initialized on GPU!") + elif is_cpu_backend(raw_keys): + max_logging.log("Attempting to initialize the jax distributed system for CPU backend...") + initialize_jax_for_cpu(raw_keys) + max_logging.log("Jax distributed system initialized on CPUs!") + elif raw_keys["enable_multi_tier_checkpointing"]: + max_logging.log( + "Attempting to initialize the jax distributed system for multi-tier " "checkpointing..." 
+ ) + initialization.initialize_multi_tier_checkpointing( + local_checkpoint_directory=raw_keys["local_checkpoint_directory"], + backup_interval_minutes=raw_keys["multi_tier_checkpointing_backup_interval_minutes"], + run_name=raw_keys["run_name"], + jax_initialization_timeout_seconds=raw_keys["jax_distributed_initialization_timeout"], + data_parallelism=raw_keys["mtc_data_parallelism"], + ) + max_logging.log("Jax distributed system initialized for multi-tier checkpointing!") + elif (raw_keys["enable_checkpointing"] and raw_keys["compile_topology_num_slices"] == -1) or raw_keys[ + "hardware" + ] == "gpu_multiprocess": + max_logging.log("Attempting to initialize the jax distributed system...") + if not raw_keys["enable_emergency_checkpoint"]: + jax.distributed.initialize( + initialization_timeout=raw_keys["jax_distributed_initialization_timeout"], + heartbeat_timeout_seconds=raw_keys["jax_distributed_heartbeat_timeout_seconds"], + ) + else: + if raw_keys["hardware"] == "gpu_multiprocess": + max_logging.log("Initializing jax distribtued to support local checkpointing with" " GPUs...") + jax.distributed.initialize( + initialization_timeout=raw_keys["jax_distributed_initialization_timeout"], + heartbeat_timeout_seconds=raw_keys["jax_distributed_heartbeat_timeout_seconds"], + ) + ocp.multihost.initialize_runtime_to_distributed_ids() + ocp.multihost.initialize_distributed_to_device_ids() + else: + initialize_jax_for_tpu_with_emergency_checkpointing(raw_keys) + max_logging.log("Jax distributed system initialized!") + + +def initialize_jax_for_gpu(raw_keys): + """Jax distributed initialize for GPUs.""" + if os.environ.get("JAX_COORDINATOR_IP") is not None: + coordinator_ip = str(os.getenv("JAX_COORDINATOR_IP")) + coordinator_port = str(os.getenv("JAX_COORDINATOR_PORT")) + jax.distributed.initialize( + coordinator_address=f"{coordinator_ip}:{coordinator_port}", + num_processes=int(os.getenv("NNODES")), + process_id=int(os.getenv("NODE_RANK")), + 
initialization_timeout=raw_keys["jax_distributed_initialization_timeout"], + heartbeat_timeout_seconds=raw_keys["jax_distributed_heartbeat_timeout_seconds"], + ) + max_logging.log(f"JAX global devices: {jax.devices()}") + + +def initialize_jax_for_cpu(raw_keys): + """Jax distributed initialize for CPUs. Includes retries until the coordinator is ready.""" + coordinator_ip_address = get_coordinator_ip_address() + coordinator_address = coordinator_ip_address + ":1234" # JAX coordinator port used in XPK + # Env variables to be set in XPK or otherwise + job_index = int(os.environ.get("JOB_INDEX")) + job_completion_index = int(os.environ.get("JOB_COMPLETION_INDEX")) + processes_in_job = int(os.environ.get("PROCESSES_IN_JOB")) + pid = job_index * processes_in_job + job_completion_index + max_logging.log(f" Jax process id is {pid} ") + # Explicit initialize is needed only for CPUs + jax.distributed.initialize( + coordinator_address=coordinator_address, + process_id=pid, + num_processes=int(os.environ.get("JAX_PROCESS_COUNT")), + initialization_timeout=raw_keys["jax_distributed_initialization_timeout"], + heartbeat_timeout_seconds=raw_keys["jax_distributed_heartbeat_timeout_seconds"], + ) + + +def initialize_jax_for_tpu_with_emergency_checkpointing(raw_keys): + """Initialize JAX distributed runtime for TPUs when emergency checkpointing is used. + The information required to initialize JAX distributed runtime will be written by GKE to + the local checkpoint directory. This function retrieves that information and initializes + JAX distributed runtime. + """ + process_id, coordinator_address = _retrieve_jax_init_info(raw_keys) + + if process_id != "" and coordinator_address != "": + max_logging.log( + f"Using {process_id} as the process_id and {coordinator_address} as the" + " coordinator_address to initialize JAX distributed runtime..." 
+ ) + jax.distributed.initialize( + coordinator_address=coordinator_address, + process_id=int(process_id), + initialization_timeout=raw_keys["jax_distributed_initialization_timeout"], + heartbeat_timeout_seconds=raw_keys["jax_distributed_heartbeat_timeout_seconds"], + ) + + ocp.multihost.initialize_runtime_to_distributed_ids() + ocp.multihost.initialize_distributed_to_device_ids() def print_system_information(): @@ -55,38 +190,19 @@ def initialize_wandb_writer(config): if jax.process_index() != 0 or not config.enable_wandb: return None - def safe_get_config(config, key, default=None): - try: - return getattr(config, key) - except KeyError: - return default - import wandb - if safe_get_config(config, "wandb_save_dir") is None or config.wandb_save_dir == "": - wandb_save_dir = os.path.join(config.base_output_directory, "wandb") - else: - wandb_save_dir = config.wandb_save_dir - - if safe_get_config(config, "wandb_project") is None or config.wandb_project == "": - wandb_project = os.getenv("WANDB_PROJECT", "Primus-MaxText-Pretrain") - else: - wandb_project = config.wandb_project - if safe_get_config(config, "wandb_exp_name") is None or config.wandb_exp_name == "": - wandb_exp_name = config.run_name - else: - wandb_exp_name = config.wandb_exp_name - - if config.enable_wandb and "WANDB_API_KEY" not in os.environ: - max_logging.log( - "The environment variable WANDB_API_KEY is not set. 
Please set it or login wandb before proceeding" - ) - return None - - os.makedirs(wandb_save_dir, exist_ok=True) - - wandb.init(project=wandb_project, name=wandb_exp_name, dir=wandb_save_dir, config=dict(config.get_keys())) - max_logging.log(f"WandB logging enabled: {wandb_save_dir=}, {wandb_project=}, {wandb_exp_name=}") + os.makedirs(config.wandb_save_dir, exist_ok=True) + + wandb.init( + project=config.wandb_project, + name=config.wandb_exp_name, + dir=config.wandb_save_dir, + config=dict(config.get_keys()), + ) + max_logging.log( + f"WandB logging enabled: {config.wandb_save_dir=}, {config.wandb_project=}, {config.wandb_exp_name=}" + ) return wandb diff --git a/primus/backends/maxtext/metric_logger.py b/primus/backends/maxtext/metric_logger.py index f249bae49..8a707808d 100644 --- a/primus/backends/maxtext/metric_logger.py +++ b/primus/backends/maxtext/metric_logger.py @@ -10,7 +10,7 @@ import jax import numpy as np from MaxText import max_logging, max_utils, maxtext_utils -from MaxText.metric_logger import MetricLogger +from MaxText.metric_logger import MetadataKey, MetricLogger from .max_utils import close_wandb_writer, initialize_wandb_writer @@ -52,10 +52,12 @@ def write_metrics_to_wandb(self, metrics, step, is_training): def write_setup_info_to_tensorboard(self, params): """Writes setup information like train config params, num model params, and XLA flags to TensorBoard.""" num_model_parameters = max_utils.calculate_num_params_from_pytree(params) - self.metadata["per_device_tflops"], _, _ = maxtext_utils.calculate_tflops_training_per_device( + self.metadata[MetadataKey.PER_DEVICE_TFLOPS], _, _ = ( + maxtext_utils.calculate_tflops_training_per_device(self.config) + ) + self.metadata[MetadataKey.PER_DEVICE_TOKENS] = maxtext_utils.calculate_tokens_training_per_device( self.config ) - self.metadata["per_device_tokens"] = maxtext_utils.calculate_tokens_training_per_device(self.config) max_logging.log(f"number parameters: {num_model_parameters/1e9:.3f} billion") 
max_utils.add_text_to_summary_writer("num_model_parameters", str(num_model_parameters), self.writer) max_utils.add_text_to_summary_writer("libtpu_init_args", os.environ["LIBTPU_INIT_ARGS"], self.writer) diff --git a/primus/backends/maxtext/train.py b/primus/backends/maxtext/train.py index 7aee03126..ac918c64e 100644 --- a/primus/backends/maxtext/train.py +++ b/primus/backends/maxtext/train.py @@ -29,18 +29,19 @@ maxtext_utils, profiler, pyconfig, + sharding, train_utils, ) -from MaxText.data_loader import DataLoader +from MaxText.common_types import ShardMode from MaxText.metric_logger import MetricLogger from MaxText.train import ( _merge_dpo_state, _split_dpo_state, eval_step, get_first_step, - setup_train_loop, train_step, ) +from MaxText.train_utils import validate_train_config from MaxText.utils import gcs_utils from MaxText.utils.goodput_utils import ( GoodputEvent, @@ -51,33 +52,6 @@ from MaxText.vertex_tensorboard import VertexTensorboardManager -def validate_train_config(config): - """Validates the configuration is set correctly for 'train.py'.""" - - assert config.run_name, "Erroring out, need a real run_name" - if config.dataset_path and not config.dataset_path.startswith("gs://"): - max_logging.log("WARNING: 'dataset_path' might be pointing your local file system") - if not config.base_output_directory.startswith("gs://"): - max_logging.log("WARNING: 'base_output_directory' might be pointing your local file system") - assert config.steps > 0, "You must set steps or learning_rate_schedule_steps to a positive integer." - - if config.quantization in ("fp8", "nanoo_fp8"): - # pylint: disable=line-too-long - assert config.gradient_accumulation_steps == 1, ( - "fp8 can't be used with gradient_accumulation_steps right now. 
Please use other quantization or set " - "gradient_accumulation_steps to 1" - ) - - # Check if GPU Flash Attention is being used with sequence packing - # if config.attention == "cudnn_flash_te" and config.packing and config.dataset_type != "synthetic": - # raise ValueError( - # "cudnn_flash_te only supports BSHD format. The THD (seq packing) support is going to be available in " - # "Transformer Engine 2.0 release. " - # "Please disable sequence packing (set packing=False) or use a different attention mechanism. " - # "With synthetic data, the format is not important as packing is not applied." - # ) - - def train_loop(config, recorder, state=None): """Main Training loop.""" ( @@ -88,9 +62,11 @@ def train_loop(config, recorder, state=None): mesh, learning_rate_schedule, data_iterator, + data_loader, + rampup_manager, eval_data_iterator, state, - ) = setup_train_loop(config, recorder) + ) = train_utils.setup_train_loop(config, recorder) if config.use_dpo: if "reference_params" not in state.params: @@ -98,35 +74,69 @@ def train_loop(config, recorder, state=None): state = _merge_dpo_state(state, reference_params) state_mesh_shardings = _merge_dpo_state(state_mesh_shardings, state_mesh_shardings.params["params"]) + params_shardings, state_mesh_shardings = sharding.maybe_update_params_sharding_with_opt( + config, state_mesh_shardings + ) + p_train_step, p_eval_step = train_utils.jit_train_and_eval_step( - config, model, mesh, state, state_mesh_shardings, train_step, eval_step, eval_data_iterator + config, + model, + mesh, + state, + state_mesh_shardings, + train_step, + eval_step, + eval_data_iterator, + params_shardings, ) with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): shaped_batch = maxtext_utils.get_shaped_batch(config) - compiled = p_train_step.lower(state, shaped_batch, init_rng).compile() - compiled_stats = compiled.memory_analysis() - max_utils.print_compiled_memory_stats(compiled_stats) + if config.shard_optimizer_over_data: + state = 
sharding.maybe_shard_with_name(state, state_mesh_shardings, config.shard_mode) + if config.compiled_trainstep_file == "": # compile only when there is no pre-compiled file loaded + compiled = p_train_step.lower(state, shaped_batch, init_rng).compile() + compiled_stats = compiled.memory_analysis() + max_utils.print_compiled_memory_stats(compiled_stats) start_step = get_first_step(state) # this is the start_step for training prof = profiler.Profiler(config, offset_step=start_step) - data_loader = DataLoader(config, mesh, data_iterator, recorder) metric_logger = MetricLogger(config=config, learning_rate_schedule=learning_rate_schedule) # Write train config params, num model params, and XLA flags to tensorboard metric_logger.write_setup_info_to_tensorboard(state.params) + # Synchronize all hosts before entering the training loop. + # Without this barrier, timing variance during initialization (JIT compilation, + # profiler/logger setup, etc.) causes hosts to enter the training loop at different + # times. The first collective operation (data sharding in load_next_batch) then + # times out waiting for straggler hosts, resulting in "collective operation timeout" + # or "stop sending heartbeats" errors. 
+ max_logging.log("====== BARRIER: Synchronizing hosts before training loop ======") + jax.experimental.multihost_utils.sync_global_devices("sync_before_training_loop") + max_logging.log("====== BARRIER PASSED: Starting training loop ======") + try: last_step_completion = datetime.datetime.now() for step in np.arange(start_step, config.steps): prof.maybe_activate_profiler(step, state) with jax.profiler.StepTraceAnnotation("train", step_num=step): - example_batch = data_loader.load_next_batch() + example_batch = data_loader.load_next_batch(rampup_manager=rampup_manager) + # Reshard data from loaded sharding to performant activation sharding + example_batch = sharding.maybe_shard_with_name( + example_batch, + sharding.get_input_data_sharding(config, mesh), + shard_mode=config.shard_mode, + ) # pylint: disable=not-callable nextrng = jax.jit(jax.random.fold_in)(init_rng, step) with maybe_record_goodput(recorder, GoodputEvent.STEP, step): with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): + if config.shard_optimizer_over_data: + state = sharding.maybe_shard_with_name( + state, state_mesh_shardings, config.shard_mode + ) state, metrics = p_train_step(state, example_batch, nextrng) jax.block_until_ready(state) @@ -150,8 +160,8 @@ def train_loop(config, recorder, state=None): if config.eval_interval > 0 and step > start_step and (step + 1) % config.eval_interval == 0: assert eval_data_iterator - - # Explicitly reset the eval counters before starting the eval loop + # Explicitly reset the eval iterator and counters before starting the eval loop + eval_data_iterator.reset() metric_logger.reset_eval_metrics() eval_step_count = 0 @@ -179,8 +189,12 @@ def train_loop(config, recorder, state=None): metric_logger.buffer_and_write_train_metrics(metrics, step, step_time_delta) - state_to_save = state if not config.use_dpo else _split_dpo_state(state)[0] - checkpointing.maybe_save_checkpoint(checkpoint_manager, state_to_save, config, data_iterator) + if 
config.save_checkpoint_on_completion: + state_to_save = state if not config.use_dpo else _split_dpo_state(state)[0] + checkpointing.maybe_save_checkpoint(checkpoint_manager, state_to_save, config, data_iterator) + if checkpoint_manager is not None: + # in case the last checkpoint_period checkpoint is still in progress + checkpoint_manager.wait_until_finished() except exceptions.StopTraining as e: max_logging.log(f"Training stopped: {str(e)}") finally: @@ -204,17 +218,19 @@ def initialize(argv: Sequence[str], **kwargs) -> tuple[pyconfig.HyperParameters, # TODO: mazumdera@ : ensure missing mandatory fields in base.yml are filled in in argv, # or fill in here config = pyconfig.initialize(argv, **kwargs) - jax.config.update("jax_use_shardy_partitioner", config.shardy) max_utils.print_system_information() validate_train_config(config) max_utils.save_device_information(config) + jax.config.update("jax_use_shardy_partitioner", config.shardy) + # update explicit sharding-supported config + if config.shard_mode == ShardMode.EXPLICIT: + jax.config.update("jax_remove_size_one_mesh_axis_from_type", True) os.environ["TFDS_DATA_DIR"] = config.dataset_path or "" vertex_tensorboard_manager = VertexTensorboardManager() if config.use_vertex_tensorboard or os.environ.get("UPLOAD_DATA_TO_TENSORBOARD"): vertex_tensorboard_manager.configure_vertex_tensorboard(config) - # Goodput configurations - maybe_monitor_goodput(config) + # Create the Goodput recorder recorder = create_goodput_recorder(config) # Stack traces configurations @@ -231,6 +247,10 @@ def initialize(argv: Sequence[str], **kwargs) -> tuple[pyconfig.HyperParameters, def run(config, recorder, diagnostic_config): """Run the job given hyperparameters and utilities""" - with diagnostic.diagnose(diagnostic_config): - with maybe_record_goodput(recorder, GoodputEvent.JOB): - train_loop(config, recorder) + with ( + diagnostic.diagnose(diagnostic_config), + maybe_record_goodput(recorder, GoodputEvent.JOB), + 
max_utils.maybe_get_transformer_engine_context(config), + maybe_monitor_goodput(config), + ): + train_loop(config, recorder) diff --git a/primus/backends/maxtext/train_utils.py b/primus/backends/maxtext/train_utils.py index 42641caaa..bfd70b1f2 100644 --- a/primus/backends/maxtext/train_utils.py +++ b/primus/backends/maxtext/train_utils.py @@ -15,26 +15,25 @@ def create_training_tools(config, model, mesh): learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(config) tx = optimizers.get_optimizer(config, learning_rate_schedule) logger = checkpointing.setup_checkpoint_logger(config) - if config.enable_emergency_checkpoint: - if config.use_replicator_service: - checkpoint_manager = checkpointing.create_orbax_emergency_replicator_checkpoint_manager( - config.local_checkpoint_directory, - config.local_checkpoint_period, - mesh, - ) - else: - abstract_state, _, _ = maxtext_utils.get_abstract_state( - model, tx, config, init_rng, mesh, is_training=True - ) - checkpoint_manager = checkpointing.create_orbax_emergency_checkpoint_manager( - config.local_checkpoint_directory, - config.checkpoint_dir, - mesh, - abstract_state, - config.local_checkpoint_period, - config.checkpoint_period, - logger, - ) + if config.enable_multi_tier_checkpointing: + checkpoint_manager = checkpointing.create_orbax_emergency_replicator_checkpoint_manager( + config.local_checkpoint_directory, + config.local_checkpoint_period, + mesh, + ) + elif config.enable_emergency_checkpoint: + abstract_state, _, _ = maxtext_utils.get_abstract_state( + model, tx, config, init_rng, mesh, is_training=True + ) + checkpoint_manager = checkpointing.create_orbax_emergency_checkpoint_manager( + config.local_checkpoint_directory, + config.checkpoint_dir, + mesh, + abstract_state, + config.local_checkpoint_period, + config.checkpoint_period, + logger, + ) else: # TODO(b/368121306): Remove this once zarr3 support is plumbed on the backend use_ocdbt = config.checkpoint_storage_use_ocdbt @@ -54,7 +53,7 @@ 
def create_training_tools(config, model, mesh): logger, use_ocdbt, use_zarr3, - config.max_to_keep, + config.max_num_checkpoints_to_keep, ) return init_rng, checkpoint_manager, learning_rate_schedule, tx diff --git a/primus/configs/models/maxtext/llama3.1_405B.yaml b/primus/configs/models/maxtext/llama3.1_405B.yaml new file mode 100644 index 000000000..1519b43e9 --- /dev/null +++ b/primus/configs/models/maxtext/llama3.1_405B.yaml @@ -0,0 +1,7 @@ +extends: + - model_base.yaml + +model_name: "llama3.1-405b" +tokenizer_path: "meta-llama/Llama-3.3-70B-Instruct" +attention: "cudnn_flash_te" +use_iota_embed: true diff --git a/primus/configs/modules/maxtext/trainer_base.yaml b/primus/configs/modules/maxtext/trainer_base.yaml index f7700b1cb..be81b358e 100644 --- a/primus/configs/modules/maxtext/trainer_base.yaml +++ b/primus/configs/modules/maxtext/trainer_base.yaml @@ -48,7 +48,6 @@ async_checkpointing: true checkpoint_period: 10_000 # enables one replica to read the ckpt then broadcast to the rest enable_single_replica_ckpt_restoring: false -max_to_keep: 5 force_unroll: false # during generate_param_only_checkpoint should we unroll the loop? @@ -121,9 +120,11 @@ save_quantized_params_path: "" model_call_mode: "" use_qwix_quantization: false # Whether to use qwix for quantization. If set to True, the model will be quantized using qwix. # Quantization calibration method used for weights and activations. Supported methods can be found in https://github.com/google/qwix/blob/dc2a0770351c740e5ab3cce7c0efe9f7beacce9e/qwix/qconfig.py#L70-L80 -quantization_calibration_method: "absmax" # Shard the range finding operation for quantization. By default this is set to number of slices. quantization_local_shard_count: -1 +act_quantization_calibration_method: absmax +bwd_quantization_calibration_method: absmax +weight_quantization_calibration_method: absmax # Global parameter scale needs to be a power of 2. 
If you want finer grained control of the model sizes # then you should explicitly set base_embed_dim, base_num_query_heads, base_num_kv_heads, @@ -154,10 +155,6 @@ megablox: true sparse_matmul: true capacity_factor: -1.0 # a factor to decide expert capacity for token dropping, and no dropping by default use_turbo_grouped_gemm: false # requires sparse_matmul=true and JAX_ENABLE_X64=1 -# Tunable tiling dimensions used for Megablox -tile_batch_seq: 512 -tile_activation_dim: 1024 -tile_weight_dim: 1024 # How the expert axis is used to shard attention weights and activations # "fsdp" (ep acts as fsdp parallelism) @@ -246,6 +243,8 @@ param_scan_axis: 1 # The attention parameter dictates the specific algorithm/methodology used to compute the attention scores # The attention_type parameter determines the variants of attention, e.g. global or local_sliding # moved to model_base.yaml +attention_bias: False +attention_sink: False # MLA parameters # moved to model_base.yaml @@ -271,12 +270,6 @@ local_checkpoint_directory: "" # It should be a positive number when and only when `enable_emergency_checkpoint` is True. local_checkpoint_period: 0 -# Whether to use emergency checkpoint with the replicator service. -use_replicator_service: false - -# The interval to backup local checkpoints to the persistent storage. 
-replicator_backup_interval_minutes: 0 - # Jax cache directory jax_cache_dir: "~/jax_cache" @@ -289,6 +282,7 @@ logical_axis_rules: [ ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_batch_no_exp', ['data', 'fsdp', 'fsdp_transpose']], ['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']], + ['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']], ['activation_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence','autoregressive']], ['activation_kv_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence']], ['activation_length', ['sequence', 'context', 'expert']], @@ -308,7 +302,7 @@ logical_axis_rules: [ ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_kv_batch_no_exp', ['data', 'fsdp', 'fsdp_transpose']], ['activation_kv_head_dim', ['tensor', 'tensor_transpose', 'tensor_sequence']], - ['activation_vocab', ['tensor', 'tensor_transpose', 'sequence', 'tensor_sequence']], + ['activation_vocab', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_vocab', ['tensor', 'tensor_transpose']], ['activation_vocab', 'tensor_sequence'], ['activation_vocab', ['sequence','context']], @@ -317,6 +311,7 @@ logical_axis_rules: [ ['decode_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['decode_length', ['sequence']], ['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']], + ['mlp_no_fsdp', ['tensor', 'tensor_sequence', 'autoregressive']], ['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['q_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], @@ -334,10 +329,12 @@ logical_axis_rules: [ ['q_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']], ['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 
'context', 'expert']], ['q_lora', ['fsdp', 'sequence', 'context', 'expert']], + ["q_lora_up_proj",[]], ['kv_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor_transpose', 'expert']], ['kv_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']], ['kv_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']], ['kv_lora', ['fsdp', 'sequence', 'context', 'expert']], + ["kv_lora_up_proj",[]], ['norm', ['tensor', 'tensor_transpose']], ['layers', 'stage'], ['kv', []], @@ -354,6 +351,8 @@ logical_axis_rules: [ ['num_pages', []], ['tokens_per_page', []], ['paged_kv_head_dim_size', []], + ['dense_layers', []], + ['moe_layers', []], ] # Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']] @@ -412,7 +411,6 @@ train_image_column: 'image' eval_data_columns: ['text'] # for DPO dataset containing "chosen" and "rejected" eval_image_column: 'image' packing: true -max_segments: 32 num_epoch: 1 # only grain and tfds pipeline supports num_epoch > 1 # direct preference optimization (DPO) @@ -449,6 +447,7 @@ hf_access_token: '' # For more details, see https://github.com/google/maxtext/blob/main/getting_started/Data_Input_Pipeline.md#grain-input-pipeline grain_train_files: '' grain_eval_files: '' +grain_train_mixture_config_path: '' # Path to a JSON file specifying the mixture weights for Grain training data. 
grain_file_type: 'arrayrecord' # arrayrecord or parquet grain_worker_count: 1 grain_worker_count_eval: 1 @@ -462,6 +461,7 @@ log_period: 100 # Flushes Tensorboard jax_distributed_initialization_timeout: 300 # This is the default timeout in https://github.com/jax-ml/jax/blob/main/jax/_src/distributed.py # Note there are two separate initializations - the jax coordination service (aka jax.distributed.initialize) and the backend (e.g. PjRT), the timeout above refers # only to the jax coordination service. +jax_distributed_heartbeat_timeout_seconds: 300 # How long before a missing heartbeat marks a task as dead. Increase for slow NFS checkpoint restores. jax_debug_log_modules: "" # Set this to "jax" to enable jax verbose logging such as for the jax coordination service initialization. skip_jax_distributed_system: false # If True we will not initialize the jax distributed system. # Currently the jax distributed is needed on cloud TPUs for async checkpointing. @@ -583,9 +583,6 @@ report_performance_metric_for_gcp_monitoring: false enable_tensorboard: true enable_wandb: false -wandb_project: "" -wandb_exp_name: "" -wandb_save_dir: "" # Vertex AI Tensorboard Configurations - https://github.com/google/maxtext/tree/main/getting_started/Use_Vertex_AI_Tensorboard.md # Set to True for GCE, False if running via XPK @@ -733,3 +730,90 @@ projector_dropout_for_vit: 0.0 # Subslice shape in the form of "x,y,z" when using pathways (single controller). # Example: "8,8" to use a 8x8 subgrid (64 chips) of a full pod (16x16) of trillium. 
subslice_shape: "" + +# tile +wi_tile_dlhs_batch_seq: 512 +wi_tile_dlhs_embed_dim: 1024 +wi_tile_dlhs_mlp_dim: 1024 +wi_tile_drhs_batch_seq: 512 +wi_tile_drhs_embed_dim: 1024 +wi_tile_drhs_mlp_dim: 1024 +wi_tile_fwd_batch_seq: 512 +wi_tile_fwd_embed_dim: 1024 +wi_tile_fwd_mlp_dim: 1024 +wo_tile_dlhs_batch_seq: 512 +wo_tile_dlhs_embed_dim: 1024 +wo_tile_dlhs_mlp_dim: 1024 +wo_tile_drhs_batch_seq: 512 +wo_tile_drhs_embed_dim: 1024 +wo_tile_drhs_mlp_dim: 1024 +wo_tile_fwd_batch_seq: 512 +wo_tile_fwd_embed_dim: 1024 +wo_tile_fwd_mlp_dim: 1024 + +chat_template_path: "" +context_parallel_strategy: all_gather +conv_stride_for_vit: 14 +cost_estimate_flops_bwd: -1 +cost_estimate_flops_fwd: -1 +debug.rl: False +deepstack_visual_indexes_for_vit: [] +dq_reduction_steps: 0 +enable_multi_tier_checkpointing: False +enable_nnx: False +enable_rampup_batch_size: False +float32_weight_sum: True +fsdp_shard_on_exp: False +gdn_chunk_size: 64 +gdn_conv_kernel_dim: 4 +gdn_key_head_dim: 128 +gdn_num_key_heads: 16 +gdn_num_value_heads: 32 +gdn_value_head_dim: 128 +generate_padding_batch_eval: False +generate_padding_batch_train: False +global_rampup_samples: 500 +grad_dtype: float32 +grain_data_source_max_workers: 16 +grain_num_threads: 16 +grain_num_threads_eval: 16 +grain_per_worker_buffer_size: 1 +grain_per_worker_buffer_size_eval: 1 +grain_prefetch_buffer_size: 500 +grain_prefetch_buffer_size_eval: 500 +hide_profiler_step_metric: False +max_num_checkpoints_to_keep: None +max_num_images_per_example: -1 +max_segments_per_seq: 32 +mlp_activations_limit: -1.0 +mlp_bias: False +moba: False +moba_chunk_size: 1024 +moba_topk: 8 +moe_fsdp_use_two_stage_all_gather: False +mtc_data_parallelism: 0 +multi_tier_checkpointing_backup_interval_minutes: 0 +num_position_embeddings_for_vit: 1024 +num_vocab_tiling: 1 +out_hidden_size_for_vit: 512 +partial_rotary_factor: 1.0 +per_device_batch_size_increment: 2.0 +per_device_batch_size_start: 4.0 +posemb_type_for_vit: learn +rope_attention_scaling: False 
+rope_interleave: True +rope_linear_scaling_factor: 1.0 +rope_truncate: True +save_checkpoint_on_completion: True +shard_mode: auto +shard_optimizer_over_data: False +spatial_merge_size_for_vit: 2 +temporal_patch_size_for_vit: 2 +use_batch_split_schedule: False +use_custom_sort_vjp: True +use_max_logit_estimate: -1 +use_qk_norm_in_gdn: True +use_ring_of_experts: False +use_tokamax_gmm: False +use_tokamax_splash: False +use_truncation: True diff --git a/primus/modules/trainer/maxtext/pre_trainer.py b/primus/modules/trainer/maxtext/pre_trainer.py index 48443a2cb..cf2146b61 100644 --- a/primus/modules/trainer/maxtext/pre_trainer.py +++ b/primus/modules/trainer/maxtext/pre_trainer.py @@ -22,6 +22,7 @@ def __init__(self, *args, **kwargs): self.patch_max_utils() self.patch_checkpoint() self.patch_input_pipeline() + self.patch_config_types() self.patch_layers() self.primus_cfg = kwargs.pop("primus_config", None) @@ -59,7 +60,7 @@ def prepare_model_overrides(self, override_args: Dict[str, Any]): """ Monkey patch maxtext cli args to override model args dynamically. Supports nested overrides like: - {"model": {"num_experts": 16, "base_num_decoder_layers": 4}} + {"override_model": {"num_experts": 16, "base_num_decoder_layers": 4}} All override keys MUST be under the "model" key. """ @@ -70,14 +71,14 @@ def prepare_model_overrides(self, override_args: Dict[str, Any]): warning_rank_0(f"MaxText Pre-Trainer: Applying override_args: {override_args}") - # --- Step 1. Flatten any nested dict under 'model' + # --- Step 1. Flatten any nested dict under 'override_model' flat_overrides = {} for k, v in override_args.items(): - if k != "model": - raise ValueError(f"Only the 'model' key is supported for overrides, found: {k}") + if k != "override_model": + raise ValueError(f"Only the 'override_model' key is supported for overrides, found: {k}") if not isinstance(v, dict): raise ValueError( - f"MaxText Pre-Trainer: The value for 'model' must be a dict, got {type(v).__name__}." 
+ f"MaxText Pre-Trainer: The value for 'override_model' must be a dict, got {type(v).__name__}." ) for subk, subv in v.items(): if isinstance(subv, dict): @@ -121,22 +122,37 @@ def patch_max_utils(self): import MaxText.max_utils as orig_max_utils from primus.backends.maxtext.max_utils import ( + initialize_jax_for_cpu, + initialize_jax_for_gpu, + initialize_jax_for_tpu_with_emergency_checkpointing, + maybe_initialize_jax_distributed_system, print_system_information, save_device_information, ) + orig_max_utils.maybe_initialize_jax_distributed_system = maybe_initialize_jax_distributed_system + orig_max_utils.initialize_jax_for_gpu = initialize_jax_for_gpu + orig_max_utils.initialize_jax_for_cpu = initialize_jax_for_cpu + orig_max_utils.initialize_jax_for_tpu_with_emergency_checkpointing = ( + initialize_jax_for_tpu_with_emergency_checkpointing + ) orig_max_utils.print_system_information = print_system_information orig_max_utils.save_device_information = save_device_information warning_rank_0("MaxText Pre-Trainer: patch max_utils successfully.") def patch_checkpoint(self): import MaxText.checkpointing as orig_checkpointing + import MaxText.train_utils as orig_train_utils from primus.backends.maxtext.checkpointing import ( create_orbax_checkpoint_manager, + load_state_if_possible, ) + from primus.backends.maxtext.train_utils import create_training_tools + orig_checkpointing.load_state_if_possible = load_state_if_possible orig_checkpointing.create_orbax_checkpoint_manager = create_orbax_checkpoint_manager + orig_train_utils.create_training_tools = create_training_tools warning_rank_0("MaxText Pre-Trainer: patch checkpointing successfully.") def patch_wandb(self): @@ -177,6 +193,14 @@ def patch_input_pipeline(self): warning_rank_0("MaxText Pre-Trainer: patch _hf_data_processing successfully.") + def patch_config_types(self): + import MaxText.configs.types as orig_config_types + + from primus.backends.maxtext.configs.types import PrimusMaxTextConfig + + 
orig_config_types.MaxTextConfig = PrimusMaxTextConfig + warning_rank_0("MaxText Pre-Trainer: patch config types successfully.") + def patch_layers(self): def patch_quantization(): import MaxText.layers.quantizations as orig_quantizations @@ -191,18 +215,14 @@ def patch_quantization(): patch_quantization() def patch_attn(): - import MaxText.layers.attention_mla as orig_attention_mla import MaxText.layers.attention_op as orig_attention_op import MaxText.layers.attentions as orig_attentions from primus.backends.maxtext.layers.attention_op import PrimusAttentionOp - from primus.backends.maxtext.layers.attentions import PrimusAttention orig_attention_op.AttentionOp = PrimusAttentionOp orig_attentions.AttentionOp = PrimusAttentionOp - orig_attentions.Attention = PrimusAttention - orig_attention_mla.Attention = PrimusAttention warning_rank_0("MaxText Pre-Trainer: patch Attention successfully.") patch_attn() @@ -216,3 +236,25 @@ def patch_moe(): warning_rank_0("MaxText Pre-Trainer: patch RoutedMoE successfully.") patch_moe() + + def patch_decoder_layer(): + import MaxText.layers.gemma as orig_gemma + import MaxText.layers.gemma2 as orig_gemma2 + import MaxText.layers.llama2 as orig_llama2 + import MaxText.layers.mistral as orig_mistral + import MaxText.layers.mixtral as orig_mixtral + + from primus.backends.maxtext.layers.gemma import PrimusGemmaDecoderLayer + from primus.backends.maxtext.layers.gemma2 import PrimusGemma2DecoderLayer + from primus.backends.maxtext.layers.llama2 import PrimusLlamaDecoderLayer + from primus.backends.maxtext.layers.mistral import PrimusMistralDecoderLayer + from primus.backends.maxtext.layers.mixtral import PrimusMixtralDecoderLayer + + orig_gemma.GemmaDecoderLayer = PrimusGemmaDecoderLayer + orig_gemma2.Gemma2DecoderLayer = PrimusGemma2DecoderLayer + orig_llama2.LlamaDecoderLayer = PrimusLlamaDecoderLayer + orig_mistral.MistralDecoderLayer = PrimusMistralDecoderLayer + orig_mixtral.MixtralDecoderLayer = PrimusMixtralDecoderLayer + 
warning_rank_0("MaxText Pre-Trainer: patch decoder layer successfully.") + + patch_decoder_layer() diff --git a/primus/pretrain.py b/primus/pretrain.py index f1136a458..7f4671096 100644 --- a/primus/pretrain.py +++ b/primus/pretrain.py @@ -112,7 +112,10 @@ def setup_backend_path(framework: str, backend_path=None, verbose: bool = True): } mapped_name = fallback_name_map.get(framework, framework) default_path = Path(__file__).resolve().parent.parent / "third_party" / mapped_name - candidate_paths.append(default_path) + if framework == "maxtext" and (default_path / "src").exists(): + default_path = default_path / "src" + candidate_paths.insert(0, str(default_path)) + print(f"[Primus] candidate_paths: {candidate_paths}") # Normalize & deduplicate candidate_paths = list(dict.fromkeys(os.path.normpath(os.path.abspath(p)) for p in candidate_paths)) diff --git a/requirements-jax.txt b/requirements-jax.txt index aebef58e3..1a65a0ff1 100644 --- a/requirements-jax.txt +++ b/requirements-jax.txt @@ -1,12 +1,3 @@ loguru wandb -expecttest pre-commit -nltk -matplotlib -markdown2 -weasyprint -tyro -blobfile -mlflow -pyrsmi diff --git a/runner/.primus.yaml b/runner/.primus.yaml index 989a7d6d1..5f1599302 100644 --- a/runner/.primus.yaml +++ b/runner/.primus.yaml @@ -47,7 +47,7 @@ container: device: - "/dev/kfd" - "/dev/dri" - # - "/dev/infiniband" + - "/dev/infiniband" # Linux capabilities (each passed as --cap-add) # NOTE: Do not modify these capabilities - they are required for proper container operation diff --git a/runner/helpers/hooks/03_enable_ainic.sh b/runner/helpers/hooks/03_enable_ainic.sh index 82b198ad0..f79bbb8e0 100755 --- a/runner/helpers/hooks/03_enable_ainic.sh +++ b/runner/helpers/hooks/03_enable_ainic.sh @@ -41,8 +41,9 @@ NCCL_IB_QPS_PER_CONNECTION="${NCCL_IB_QPS_PER_CONNECTION:-1}" # LD_LIBRARY_PATH: prepend AINIC/RCCL/MPI paths while preserving existing. 
_ld_base="/usr/lib/x86_64-linux-gnu:/usr/lib/x86_64-linux-gnu/libibverbs:${RCCL_HOME_DIR}/build/release:${ANP_HOME_DIR}/build:${MPI_HOME_DIR}/install/lib" -LD_LIBRARY_PATH="${_ld_base}${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}" - +# Need to append AINIC/RCCL/MPI paths to the existing LD_LIBRARY_PATH. Otherwise, +# JAX MaxText will not find the appropriate ROCm libraries. +LD_LIBRARY_PATH="${LD_LIBRARY_PATH:+${LD_LIBRARY_PATH}:}${_ld_base}" LOG_INFO_RANK0 "Using AINIC" LOG_INFO_RANK0 "RCCL_HOME_DIR: ${RCCL_HOME_DIR}" LOG_INFO_RANK0 "ANP_HOME_DIR: ${ANP_HOME_DIR}" diff --git a/runner/helpers/hooks/train/pretrain/maxtext/prepare.py b/runner/helpers/hooks/train/pretrain/maxtext/prepare.py index 917f7cf8f..75bbe8d73 100644 --- a/runner/helpers/hooks/train/pretrain/maxtext/prepare.py +++ b/runner/helpers/hooks/train/pretrain/maxtext/prepare.py @@ -141,7 +141,6 @@ def install_maxtext_dependencies() -> None: cmd = ( "apt install iproute2 -y && " "apt install -y " - 'linux-headers-"$(uname -r)" ' "libelf-dev " "gcc make libtool autoconf " "librdmacm-dev rdmacm-utils infiniband-diags ibverbs-utils perftest ethtool " @@ -237,14 +236,18 @@ def main(): print(f"env.DUMP_HLO_DIR={dump_hlo_dir}") print(f"env.DUMP_HLO={dump_hlo}") print("env.NVTE_ALLOW_NONDETERMINISTIC_ALGO=1") - print("env.XLA_PYTHON_CLIENT_MEM_FRACTION=.97") + # set XLA_PYTHON_CLIENT_MEM_FRACTION to 0.93 + # to avoid HSA_STATUS_ERROR_OUT_OF_RESOURCES error during multi-node training + print("env.XLA_PYTHON_CLIENT_MEM_FRACTION=.93") print("env.NVTE_USE_HIPBLASLT=1") - xla_flags = "--xla_gpu_memory_limit_slop_factor=95 --xla_gpu_reduce_scatter_combine_threshold_bytes=8589934592 --xla_gpu_graph_level=0 --xla_gpu_enable_latency_hiding_scheduler=True --xla_gpu_all_gather_combine_threshold_bytes=8589934592 --xla_gpu_enable_triton_gemm=False --xla_gpu_enable_cublaslt=True --xla_gpu_autotune_level=0 --xla_gpu_enable_all_gather_combine_by_dim=FALSE" + xla_flags = "--xla_gpu_memory_limit_slop_factor=95 
--xla_gpu_reduce_scatter_combine_threshold_bytes=8589934592 --xla_gpu_enable_command_buffer='' --xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_all_gather_combine_threshold_bytes=8589934592 --xla_gpu_enable_triton_gemm=false --xla_gpu_enable_cublaslt=true --xla_gpu_autotune_level=0 --xla_gpu_enable_all_gather_combine_by_dim=false" if dump_hlo == "1": xla_flags += f" --xla_dump_to={dump_hlo_dir}" log_info(f"XLA HLO dumping enabled, output directory: {dump_hlo_dir}") print(f"env.XLA_FLAGS={xla_flags}") + # set TF_CPP_MIN_LOG_LEVEL=2 to suppress the error messages at the end of JAX/MaxText training + print(f"env.TF_CPP_MIN_LOG_LEVEL=2") # AMD GPU optimizations print("env.HIP_FORCE_DEV_KERNARG=1") diff --git a/runner/primus-cli-direct.sh b/runner/primus-cli-direct.sh index 1dddcb551..4eb6c92c4 100755 --- a/runner/primus-cli-direct.sh +++ b/runner/primus-cli-direct.sh @@ -337,9 +337,45 @@ mkdir -p "$(dirname "${direct_config[log_file]:-}")" ############################################################################### # STEP 5: Install dependencies ############################################################################### +# Detect the backend framework from the experiment YAML (--config in PRIMUS_ARGS) +# so we can install the correct requirements file: +# maxtext -> requirements-jax.txt +# others -> requirements.txt +_detect_framework() { + local cfg_path="" + local args=("${primus_args[@]}") + for ((i=0; i<${#args[@]}; i++)); do + if [[ "${args[$i]}" == "--config" && -n "${args[$((i+1))]:-}" ]]; then + cfg_path="${args[$((i+1))]}" + break + fi + done + if [[ -z "$cfg_path" || ! 
-f "$cfg_path" ]]; then + echo "" + return + fi + python3 -c " +import yaml, sys +try: + cfg = yaml.safe_load(open('$cfg_path')) + print(cfg.get('modules',{}).get('pre_trainer',{}).get('framework','')) +except Exception: + print('') +" 2>/dev/null +} + +DETECTED_FRAMEWORK="$(_detect_framework)" +LOG_INFO_RANK0 "[direct] Detected framework: ${DETECTED_FRAMEWORK:-unknown}" + # Skip pip install in dry-run mode if [[ "$DRY_RUN_MODE" != "1" ]]; then - pip install -qq -r requirements.txt + if [[ "$DETECTED_FRAMEWORK" == "maxtext" ]]; then + LOG_INFO_RANK0 "[direct] Installing JAX dependencies (requirements-jax.txt)" + pip install -qq -r requirements-jax.txt + else + LOG_INFO_RANK0 "[direct] Installing PyTorch dependencies (requirements.txt)" + pip install -qq -r requirements.txt + fi fi ############################################################################### diff --git a/runner/use_ainic.yaml b/runner/use_ainic.yaml new file mode 100644 index 000000000..83c8b9677 --- /dev/null +++ b/runner/use_ainic.yaml @@ -0,0 +1,93 @@ +# Primus CLI System Default Configuration +# This file provides system-wide default settings for Primus CLI +# +# Priority: CLI args > User config (~/.primus.yaml) > System defaults (this file) + +# Main settings (apply to all modes) +main: + debug: false + dry_run: false + +# Slurm-specific settings +slurm: + debug: false + dry_run: false + # partition: "" + nodes: 1 + gpus_per_node: 8 + time: "4:00:00" + +# Container-specific settings +container: + debug: false + dry_run: false + + # Docker/Podman runtime options + # All keys directly map to CLI arguments (--key value) + options: + # Container image + image: "rocm/jax-training:maxtext-v26.1" + + # Single-value options + ipc: "host" + network: "host" + # cpus: "96" + # memory: "256G" + name: "primus-training" + privileged: "true" + security-opt: "seccomp=unconfined" + group-add: "video" + + # Cumulative options (can be specified multiple times via CLI) + # Device access (each passed as --device) 
+ # NOTE: Do not modify these device paths - they are required for ROCm/GPU access + # /dev/kfd - Kernel Fusion Driver (ROCm core) + # /dev/dri - Direct Rendering Infrastructure (GPU access) + # /dev/infiniband - InfiniBand network device (multi-node communication) + device: + - "/dev/kfd" + - "/dev/dri" + - "/dev/infiniband" + + # Linux capabilities (each passed as --cap-add) + # NOTE: Do not modify these capabilities - they are required for proper container operation + # SYS_PTRACE - Required for debugging and profiling tools + # CAP_SYS_ADMIN - Required for system administration operations + cap-add: + - "SYS_PTRACE" + - "CAP_SYS_ADMIN" + + # Volume mounts (each passed as --volume) + volume: [] + # volume: + # - "/data:/data" + # - "/output:/output" + # - "/workspace/Primus" + # - "/model_weights:/model_weights:ro" + + # Environment variables (each passed as --env KEY=VALUE) + env: + # If using AINIC, set the environment variables for AINIC + # make sure NCCL_IB_GID_INDEX value is set appropriately + - "USING_AINIC=1" + - "NCCL_PXN_DISABLE=0" + - "NCCL_IB_GID_INDEX=1" +# Direct mode settings +direct: + debug: false + + # Distributed training parameters + gpus_per_node: 8 + master_port: 1234 + nnodes: 1 + master_addr: "localhost" + + # Direct mode specific options + run_mode: "torchrun" + script: "primus/cli/main.py" + numa: "auto" + log_file: "" + + # Default patch scripts and env vars + patch: [] + env: [] diff --git a/tests/trainer/test_maxtext_trainer.py b/tests/trainer/test_maxtext_trainer.py index 6e04344e0..945ea6aef 100644 --- a/tests/trainer/test_maxtext_trainer.py +++ b/tests/trainer/test_maxtext_trainer.py @@ -88,7 +88,7 @@ def test_llama3_8B_BF16(self): "llama3_8B-BF16", exp_path="examples/maxtext/configs/MI300X/llama3_8B-pretrain.yaml", extra_args=[ - "--model.base_num_decoder_layers", + "--override_model.base_num_decoder_layers", "4", "--steps", "3", @@ -101,7 +101,7 @@ def test_llama3_8B_FP8(self): "llama3_8B-FP8", 
exp_path="examples/maxtext/configs/MI300X/llama3_8B-pretrain.yaml", extra_args=[ - "--model.base_num_decoder_layers", + "--override_model.base_num_decoder_layers", "4", "--steps", "3", @@ -117,7 +117,7 @@ def test_llama3_70B_BF16(self): "llama3_70B-BF16", exp_path="examples/maxtext/configs/MI300X/llama3_70B-pretrain.yaml", extra_args=[ - "--model.base_num_decoder_layers", + "--override_model.base_num_decoder_layers", "4", "--steps", "3", @@ -131,7 +131,7 @@ def test_llama3_70B_FP8(self): "llama3_70B-FP8", exp_path="examples/maxtext/configs/MI300X/llama3_70B-pretrain.yaml", extra_args=[ - "--model.base_num_decoder_layers", + "--override_model.base_num_decoder_layers", "4", "--steps", "3", @@ -147,7 +147,7 @@ def test_llama3_3_70B_BF16(self): "llama3_3_70B-BF16", exp_path="examples/maxtext/configs/MI300X/llama3.3_70B-pretrain.yaml", extra_args=[ - "--model.base_num_decoder_layers", + "--override_model.base_num_decoder_layers", "4", "--steps", "3", @@ -161,7 +161,7 @@ def test_llama3_3_70B_FP8(self): "llama3_3_70B-FP8", exp_path="examples/maxtext/configs/MI300X/llama3.3_70B-pretrain.yaml", extra_args=[ - "--model.base_num_decoder_layers", + "--override_model.base_num_decoder_layers", "4", "--steps", "3", @@ -177,7 +177,7 @@ def test_llama2_7B_BF16(self): "llama2_7B-BF16", exp_path="examples/maxtext/configs/MI300X/llama2_7B-pretrain.yaml", extra_args=[ - "--model.base_num_decoder_layers", + "--override_model.base_num_decoder_layers", "4", "--steps", "3", @@ -191,7 +191,7 @@ def test_llama2_7B_FP8(self): "llama2_7B-FP8", exp_path="examples/maxtext/configs/MI300X/llama2_7B-pretrain.yaml", extra_args=[ - "--model.base_num_decoder_layers", + "--override_model.base_num_decoder_layers", "4", "--steps", "3", @@ -207,7 +207,7 @@ def test_llama2_70B_BF16(self): "llama2_70B-BF16", exp_path="examples/maxtext/configs/MI300X/llama2_70B-pretrain.yaml", extra_args=[ - "--model.base_num_decoder_layers", + "--override_model.base_num_decoder_layers", "4", "--steps", "3", @@ -221,7 
+221,7 @@ def test_llama2_70B_FP8(self): "llama2_70B-FP8", exp_path="examples/maxtext/configs/MI300X/llama2_70B-pretrain.yaml", extra_args=[ - "--model.base_num_decoder_layers", + "--override_model.base_num_decoder_layers", "4", "--steps", "3", @@ -237,7 +237,7 @@ def test_mixtral_8x7B_BF16(self): "mixtral_8x7B-BF16", exp_path="examples/maxtext/configs/MI300X/mixtral_8x7B-pretrain.yaml", extra_args=[ - "--model.base_num_decoder_layers", + "--override_model.base_num_decoder_layers", "4", "--steps", "3", @@ -251,7 +251,7 @@ def test_mixtral_8x7B_FP8(self): "mixtral_8x7B-FP8", exp_path="examples/maxtext/configs/MI300X/mixtral_8x7B-pretrain.yaml", extra_args=[ - "--model.base_num_decoder_layers", + "--override_model.base_num_decoder_layers", "4", "--steps", "3", @@ -267,7 +267,7 @@ def test_grok1_BF16(self): "grok1-BF16", exp_path="examples/maxtext/configs/MI300X/grok1-pretrain.yaml", extra_args=[ - "--model.base_num_decoder_layers", + "--override_model.base_num_decoder_layers", "4", "--steps", "3", @@ -281,7 +281,7 @@ def test_grok1_FP8(self): "grok1-FP8", exp_path="examples/maxtext/configs/MI300X/grok1-pretrain.yaml", extra_args=[ - "--model.base_num_decoder_layers", + "--override_model.base_num_decoder_layers", "4", "--steps", "3", @@ -297,7 +297,7 @@ def test_dpsk_v2_16B_BF16(self): "dpsk_v2_16B-BF16", exp_path="examples/maxtext/configs/MI300X/deepseek_v2_16B-pretrain.yaml", extra_args=[ - "--model.base_num_decoder_layers", + "--override_model.base_num_decoder_layers", "4", "--steps", "3", @@ -311,7 +311,7 @@ def test_dpsk_v2_16B_FP8(self): "dpsk_v2_16B-FP8", exp_path="examples/maxtext/configs/MI300X/deepseek_v2_16B-pretrain.yaml", extra_args=[ - "--model.base_num_decoder_layers", + "--override_model.base_num_decoder_layers", "4", "--steps", "3", diff --git a/third_party/maxtext b/third_party/maxtext index 8def32a8a..022dc02eb 160000 --- a/third_party/maxtext +++ b/third_party/maxtext @@ -1 +1 @@ -Subproject commit 8def32a8a5b96fc6267636a8e58abfe4c178e161 
+Subproject commit 022dc02eb89057350d2e365f23c8f1f0edb4732d