diff --git a/.github/workflows/gpu-tests.yml b/.github/workflows/gpu-tests.yml index 8d9bfed..a640de5 100644 --- a/.github/workflows/gpu-tests.yml +++ b/.github/workflows/gpu-tests.yml @@ -33,19 +33,26 @@ jobs: python -m pip install --upgrade pip python -m pip install --upgrade uv python -m uv pip install -U pytest "jax[cuda12]" - python -m uv pip install nvidia-cusolver-cu12==11.7.3.90 - python -m uv pip install nvidia-cublas-cu12 - python -m uv pip install jax-triton triton==3.3.1 - # python -m uv pip uninstall cuequivariance cuequivariance_jax cuequivariance_torch python -m uv pip install cuequivariance-ops-cu12 cuequivariance-ops-jax-cu12 + + # Add NVIDIA CUDA libraries to LD_LIBRARY_PATH + SITE_PACKAGES=$(python -c "import site; print(' '.join(site.getsitepackages()))") + CUDA_LIB_DIRS=$(find $SITE_PACKAGES -path "*/nvidia/*/lib" -type d 2>/dev/null | tr '\n' ':') + export LD_LIBRARY_PATH="$CUDA_LIB_DIRS$LD_LIBRARY_PATH" + + python -c "import cuequivariance_ops; print('cueop', cuequivariance_ops.__version__)" + python -c "import cuequivariance_ops_jax; print('cueopx', cuequivariance_ops_jax.__version__)" + python -m uv pip install -e ./cuequivariance python -m uv pip install -e ./cuequivariance_jax - - # python -c "import cuequivariance; print('cue', cuequivariance.__version__)" - # python -c "import cuequivariance_jax; print('cuex', cuequivariance_jax.__version__)" + python -c "import cuequivariance; print('cue', cuequivariance.__version__)" + python -c "import cuequivariance_jax; print('cuex', cuequivariance_jax.__version__)" - name: Test with pytest run: | - # XLA_PYTHON_CLIENT_PREALLOCATE=false pytest --doctest-modules -x -m "not slow" cuequivariance_jax - echo "skipping tests" + # Set up CUDA library path for tests + SITE_PACKAGES=$(python -c "import site; print(' '.join(site.getsitepackages()))") + CUDA_LIB_DIRS=$(find $SITE_PACKAGES -path "*/nvidia/*/lib" -type d 2>/dev/null | tr '\n' ':') + export LD_LIBRARY_PATH="$CUDA_LIB_DIRS$LD_LIBRARY_PATH" + XLA_PYTHON_CLIENT_PREALLOCATE=false pytest --doctest-modules -x -m "not slow" cuequivariance_jax