File tree Expand file tree Collapse file tree 4 files changed +21
-8
lines changed Expand file tree Collapse file tree 4 files changed +21
-8
lines changed Original file line number Diff line number Diff line change 3535 - name : build maxdiffusion jax nightly image
3636 run : |
3737 bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_nightly MODE=nightly PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxdiffusion_jax_nightly
38+
39+ build-gpu-image :
40+ runs-on : ["self-hosted", "e2", "cpu"]
41+ steps :
42+ - uses : actions/checkout@v3
43+ - name : Cleanup old docker images
44+ run : docker system prune --all --force
45+ - name : build maxdiffusion jax stable stack gpu image
46+ run : |
47+ bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_stable_stack_gpu MODE=stable_stack PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxdiffusion_jax_stable_stack_gpu BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:latest DEVICE=gpu
48+ - name : build maxdiffusion jax nightly image
49+ run : |
50+ bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_nightly_gpu MODE=nightly PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxdiffusion_jax_nightly DEVICE=gpu
Original file line number Diff line number Diff line change @@ -34,13 +34,15 @@ for ARGUMENT in "$@"; do
3434 echo " $KEY " =" $VALUE "
3535done
3636
37+ export DEVICE=" ${DEVICE:- tpu} "
38+
3739if [[ ! -v CLOUD_IMAGE_NAME ]] || [[ ! -v PROJECT ]] || [[ ! -v MODE ]] ; then
3840 echo " You must set CLOUD_IMAGE_NAME, PROJECT and MODE"
3941 exit 1
4042fi
4143
4244gcloud auth configure-docker us-docker.pkg.dev --quiet
43- bash docker_build_dependency_image.sh LOCAL_IMAGE_NAME=$LOCAL_IMAGE_NAME MODE=$MODE
45+ bash docker_build_dependency_image.sh LOCAL_IMAGE_NAME=$LOCAL_IMAGE_NAME MODE=$MODE DEVICE= $DEVICE
4446image_date=$( date +%Y-%m-%d)
4547
4648# Upload only dependencies image
Original file line number Diff line number Diff line change @@ -22,8 +22,7 @@ RUN apt-get update && apt-get install -y google-cloud-sdk
2222# Set environment variables for Google Cloud SDK
2323ENV PATH="/usr/local/google-cloud-sdk/bin:${PATH}"
2424
25- # Upgrade libcusprase to work with Jax
26- RUN apt-get update && apt-get install -y libcusparse-12-3
25+
2726
2827ARG MODE
2928ENV ENV_MODE=$MODE
@@ -46,5 +45,4 @@ RUN ls .
4645RUN echo "Running command: bash setup.sh MODE=$ENV_MODE JAX_VERSION=$ENV_JAX_VERSION DEVICE=${ENV_DEVICE}"
4746RUN --mount=type=cache,target=/root/.cache/pip bash setup.sh MODE=${ENV_MODE} JAX_VERSION=${ENV_JAX_VERSION} DEVICE=${ENV_DEVICE}
4847
49-
5048WORKDIR /deps
Original file line number Diff line number Diff line change @@ -55,6 +55,9 @@ if [[ -n $JAX_VERSION && ! ($MODE == "stable" || -z $MODE) ]]; then
5555 exit 1
5656fi
5757
58+ # Install dependencies from requirements.txt first
59+ pip3 install -U -r requirements.txt || echo " Failed to install dependencies in the requirements" >&2
60+
5861# Install JAX and JAXlib based on the specified mode
5962if [[ " $MODE " == " stable" || ! -v MODE ]]; then
6063 # Stable mode
@@ -78,7 +81,7 @@ if [[ "$MODE" == "stable" || ! -v MODE ]]; then
7881 pip3 install " jax[cuda12]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
7982 fi
8083 export NVTE_FRAMEWORK=jax
81- pip3 install git+https://github.com/NVIDIA/TransformerEngine.git@stable
84+ pip3 install transformer_engine[jax]==2.1.0
8285 fi
8386
8487elif [[ $MODE == " nightly" ]]; then
106109 exit 1
107110fi
108111
109- # Install dependencies from requirements.txt
110- pip3 install -U -r requirements.txt || echo " Failed to install dependencies in the requirements" >&2
111-
112112# Install maxdiffusion
113113pip3 install -U . || echo " Failed to install maxdiffusion" >&2
You can’t perform that action at this time.
0 commit comments