diff --git a/.github/workflows/android.yml b/.github/workflows/android.yml index f12eadc2ce794..7f7ff74959d52 100644 --- a/.github/workflows/android.yml +++ b/.github/workflows/android.yml @@ -27,7 +27,7 @@ jobs: runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] steps: - name: Checkout repository - uses: actions/checkout@v6 + uses: actions/checkout@v5 with: submodules: false @@ -112,7 +112,7 @@ jobs: android_nnapi_ep: runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@v5 - name: Use jdk 17 uses: actions/setup-java@v5 @@ -187,7 +187,7 @@ jobs: name: Android CI Pipeline runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@v5 - name: Use jdk 17 uses: actions/setup-java@v5 diff --git a/.github/workflows/cffconvert.yml b/.github/workflows/cffconvert.yml index ddf4a52a0ccb0..30f832f67c5ee 100644 --- a/.github/workflows/cffconvert.yml +++ b/.github/workflows/cffconvert.yml @@ -12,7 +12,7 @@ jobs: runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] steps: - name: Check out a copy of the repository - uses: actions/checkout@v6 + uses: actions/checkout@v5 - name: Check whether the citation metadata from CITATION.cff is valid uses: citation-file-format/cffconvert-github-action@2.0.0 diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml index 1db84400c272a..d33e4d923a0bc 100644 --- a/.github/workflows/codeql.yml +++ b/.github/workflows/codeql.yml @@ -38,7 +38,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v6 + uses: actions/checkout@v5 # Initializes the CodeQL tools for scanning. - name: Initialize CodeQL diff --git a/.github/workflows/gradle-wrapper-validation.yml b/.github/workflows/gradle-wrapper-validation.yml index d8f13d13d3f88..04177b11e9c30 100644 --- a/.github/workflows/gradle-wrapper-validation.yml +++ b/.github/workflows/gradle-wrapper-validation.yml @@ -15,7 +15,7 @@ jobs: name: "Validation" runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@v5 - uses: gradle/actions/wrapper-validation@v5 concurrency: group: ${{ github.workflow }}-${{ github.event_name == 'pull_request' && github.ref || github.sha }} diff --git a/.github/workflows/ios.yml b/.github/workflows/ios.yml index ed572aa339ce9..0d2046b980783 100644 --- a/.github/workflows/ios.yml +++ b/.github/workflows/ios.yml @@ -20,7 +20,7 @@ jobs: runs-on: macos-14 steps: - name: Checkout repository - uses: actions/checkout@v6 + uses: actions/checkout@v5 with: submodules: false - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.9 diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 5c618dc5787a5..5aaab5f8e1a10 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -17,7 +17,7 @@ jobs: name: Optional Lint runs-on: ubuntu-latest steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@v5 - name: misspell # Check spellings as well uses: reviewdog/action-misspell@v1 with: @@ -42,7 +42,7 @@ jobs: contents: read security-events: write steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@v5 - name: Setup Python uses: actions/setup-python@v6 with: @@ -87,7 +87,7 @@ jobs: name: Optional Lint C++ runs-on: ubuntu-latest steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@v5 - name: Update PATH run: | echo 
"$HOME/.local/bin" >> "$GITHUB_PATH" @@ -116,7 +116,7 @@ jobs: name: Lint JavaScript runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@v5 - uses: actions/setup-node@v6 with: node-version: 20 diff --git a/.github/workflows/linux-wasm-ci-build-and-test-workflow.yml b/.github/workflows/linux-wasm-ci-build-and-test-workflow.yml index 5763b9c39bcc6..2370c631b7a7a 100644 --- a/.github/workflows/linux-wasm-ci-build-and-test-workflow.yml +++ b/.github/workflows/linux-wasm-ci-build-and-test-workflow.yml @@ -49,7 +49,7 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v6 + uses: actions/checkout@v5 with: submodules: recursive diff --git a/.github/workflows/linux_cuda_ci.yml b/.github/workflows/linux_cuda_ci.yml index e7e3be8c5f9ed..886705471b7de 100644 --- a/.github/workflows/linux_cuda_ci.yml +++ b/.github/workflows/linux_cuda_ci.yml @@ -48,7 +48,7 @@ jobs: packages: read steps: - name: Checkout code - uses: actions/checkout@v6 + uses: actions/checkout@v5 - uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.9 id: build_docker_image_step diff --git a/.github/workflows/linux_minimal_build.yml b/.github/workflows/linux_minimal_build.yml index 4d9579a746892..af86975ee6cdc 100644 --- a/.github/workflows/linux_minimal_build.yml +++ b/.github/workflows/linux_minimal_build.yml @@ -28,7 +28,7 @@ jobs: packages: write steps: - name: Checkout repository - uses: actions/checkout@v6 + uses: actions/checkout@v5 with: submodules: false @@ -65,7 +65,7 @@ jobs: id-token: write # If using OIDC for ACR login steps: - name: Checkout repository - uses: actions/checkout@v6 + uses: actions/checkout@v5 with: submodules: false - uses: actions/setup-node@v6 @@ -122,7 +122,7 @@ jobs: id-token: write # If using OIDC for ACR login steps: - name: Checkout repository - uses: actions/checkout@v6 + uses: actions/checkout@v5 with: submodules: false - uses: actions/setup-node@v6 @@ -156,7 +156,7 @@ jobs: id-token: write # If using OIDC for ACR login steps: - name: Checkout repository - uses: actions/checkout@v6 + uses: actions/checkout@v5 with: submodules: false - uses: actions/setup-node@v6 @@ -188,7 +188,7 @@ jobs: id-token: write # If using OIDC for ACR login steps: - name: Checkout repository - uses: actions/checkout@v6 + uses: actions/checkout@v5 with: submodules: false - uses: actions/setup-node@v6 @@ -222,7 +222,7 @@ jobs: id-token: write # If using OIDC for ACR login steps: - name: Checkout repository - uses: actions/checkout@v6 + uses: actions/checkout@v5 with: submodules: false - uses: actions/setup-node@v6 @@ -286,7 +286,7 @@ jobs: id-token: write # If using OIDC for ACR login steps: - name: Checkout repository - uses: actions/checkout@v6 + uses: actions/checkout@v5 with: submodules: false @@ -363,7 +363,7 @@ jobs: id-token: write # If using OIDC for ACR login steps: - name: Checkout repository - uses: actions/checkout@v6 + uses: actions/checkout@v5 with: submodules: false @@ -430,7 +430,7 @@ jobs: id-token: write # If using OIDC for ACR login steps: - name: Checkout repository - uses: actions/checkout@v6 + uses: actions/checkout@v5 with: submodules: false @@ -505,7 +505,7 @@ jobs: id-token: write # If using OIDC for ACR login steps: - name: Checkout repository - uses: actions/checkout@v6 + uses: actions/checkout@v5 with: submodules: false - uses: actions/setup-node@v6 diff --git a/.github/workflows/linux_tensorrt_ci.yml b/.github/workflows/linux_tensorrt_ci.yml index 47b7c1ba7e889..0e26576829e94 100644 
--- a/.github/workflows/linux_tensorrt_ci.yml +++ b/.github/workflows/linux_tensorrt_ci.yml @@ -48,7 +48,7 @@ jobs: packages: read steps: - name: Checkout code - uses: actions/checkout@v6 + uses: actions/checkout@v5 # --- Build the Docker image needed for testing --- - name: Build Docker Image for Testing diff --git a/.github/workflows/mac.yml b/.github/workflows/mac.yml index 8ba87bc1f731c..e545406d8d20f 100644 --- a/.github/workflows/mac.yml +++ b/.github/workflows/mac.yml @@ -76,7 +76,7 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v6 + uses: actions/checkout@v5 - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.9 with: vcpkg-version: '2025.06.13' @@ -124,7 +124,7 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v6 + uses: actions/checkout@v5 - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.9 with: vcpkg-version: '2025.06.13' diff --git a/.github/workflows/macos-ci-build-and-test-workflow.yml b/.github/workflows/macos-ci-build-and-test-workflow.yml index 8e1d0264496f6..329584c68d7d1 100644 --- a/.github/workflows/macos-ci-build-and-test-workflow.yml +++ b/.github/workflows/macos-ci-build-and-test-workflow.yml @@ -75,7 +75,7 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v6 + uses: actions/checkout@v5 - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.9 with: vcpkg-version: '2025.06.13' diff --git a/.github/workflows/pr_checks.yml b/.github/workflows/pr_checks.yml index 7ca330742f69b..abe627f4ff7bc 100644 --- a/.github/workflows/pr_checks.yml +++ b/.github/workflows/pr_checks.yml @@ -24,7 +24,7 @@ jobs: contents: read pull-requests: write steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@v5 - name: Setup Python uses: actions/setup-python@v6 with: diff --git a/.github/workflows/publish-c-apidocs.yml b/.github/workflows/publish-c-apidocs.yml index d9fb72271967f..25b7899584bbf 100644 --- a/.github/workflows/publish-c-apidocs.yml +++ b/.github/workflows/publish-c-apidocs.yml @@ -24,7 +24,7 @@ jobs: name: Generate C/C++ API docs runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@v5 - name: Install doxygen and dependencies run: | sudo apt update diff --git a/.github/workflows/publish-csharp-apidocs.yml b/.github/workflows/publish-csharp-apidocs.yml index dd55bbd917337..34b9c1af9552f 100644 --- a/.github/workflows/publish-csharp-apidocs.yml +++ b/.github/workflows/publish-csharp-apidocs.yml @@ -24,7 +24,7 @@ jobs: env: DOCFXVERSION: 2.62.2 steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@v5 - name: Install DocFX run: | dotnet tool update -g docfx diff --git a/.github/workflows/publish-java-apidocs.yml b/.github/workflows/publish-java-apidocs.yml index 81defeae518a3..656d0627ed17d 100644 --- a/.github/workflows/publish-java-apidocs.yml +++ b/.github/workflows/publish-java-apidocs.yml @@ -23,7 +23,7 @@ jobs: name: Generate Java docs runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@v5 - name: Set up JDK 11 uses: actions/setup-java@v5 with: diff --git a/.github/workflows/publish-js-apidocs.yml b/.github/workflows/publish-js-apidocs.yml index 9da78d7d9ed9c..e71d3b3c57a4b 100644 --- a/.github/workflows/publish-js-apidocs.yml +++ b/.github/workflows/publish-js-apidocs.yml @@ -23,7 +23,7 @@ jobs: name: Generate JS API docs runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] 
steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@v5 - name: Setup Node.js uses: actions/setup-node@v6 with: diff --git a/.github/workflows/publish-objectivec-apidocs.yml b/.github/workflows/publish-objectivec-apidocs.yml index a73b62eba6050..983d3d478a49d 100644 --- a/.github/workflows/publish-objectivec-apidocs.yml +++ b/.github/workflows/publish-objectivec-apidocs.yml @@ -23,7 +23,7 @@ jobs: name: Generate Objective-C API docs runs-on: macos-latest steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@v5 - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.9 with: vcpkg-version: '2025.06.13' diff --git a/.github/workflows/publish-python-apidocs.yml b/.github/workflows/publish-python-apidocs.yml index e35e6a04adbef..389d1683fb1ff 100644 --- a/.github/workflows/publish-python-apidocs.yml +++ b/.github/workflows/publish-python-apidocs.yml @@ -24,7 +24,7 @@ jobs: name: Generate Python API docs runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@v5 - name: Install tools run: | sudo apt-get update diff --git a/.github/workflows/react_native.yml b/.github/workflows/react_native.yml index 4a56dfbd35406..343186b1aec8c 100644 --- a/.github/workflows/react_native.yml +++ b/.github/workflows/react_native.yml @@ -20,7 +20,7 @@ jobs: aar_path: ${{ runner.temp }}/.artifacts steps: - name: Checkout repository - uses: actions/checkout@v6 + uses: actions/checkout@v5 with: submodules: false @@ -75,7 +75,7 @@ jobs: run: echo "ANDROID_AVD_HOME=${{ runner.temp }}/android-avd" >> $GITHUB_ENV - name: Checkout repository - uses: actions/checkout@v6 + uses: actions/checkout@v5 - name: Use Python 3.12 uses: actions/setup-python@v6 @@ -175,7 +175,7 @@ jobs: timeout-minutes: 120 steps: - name: Checkout repository - uses: actions/checkout@v6 + uses: actions/checkout@v5 - name: Use Xcode 15.3.0 run: sudo xcode-select --switch /Applications/Xcode_15.3.0.app/Contents/Developer @@ -218,7 +218,7 @@ jobs: timeout-minutes: 90 steps: - name: Checkout repository - uses: actions/checkout@v6 + uses: actions/checkout@v5 - name: Download iOS pod artifact uses: actions/download-artifact@v6 diff --git a/.github/workflows/reusable_linux_build.yml b/.github/workflows/reusable_linux_build.yml index f0da87647b8b0..795e35b06bfb0 100644 --- a/.github/workflows/reusable_linux_build.yml +++ b/.github/workflows/reusable_linux_build.yml @@ -75,7 +75,7 @@ jobs: id-token: write steps: - name: Checkout code - uses: actions/checkout@v6 + uses: actions/checkout@v5 - name: Set up Python ${{ inputs.python_version }} if: inputs.architecture != 'arm64' diff --git a/.github/workflows/web.yml b/.github/workflows/web.yml index 6ae25ccc0bf3e..016feab5e0d94 100644 --- a/.github/workflows/web.yml +++ b/.github/workflows/web.yml @@ -22,7 +22,7 @@ jobs: commit_sha: ${{ steps.extract_commit.outputs.commit_sha }} steps: - name: Checkout code - uses: actions/checkout@v6 + uses: actions/checkout@v5 with: submodules: true diff --git a/.github/workflows/windows-web-ci-workflow.yml b/.github/workflows/windows-web-ci-workflow.yml index c16ce6eb222eb..eee98332056f6 100644 --- a/.github/workflows/windows-web-ci-workflow.yml +++ b/.github/workflows/windows-web-ci-workflow.yml @@ -29,7 +29,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v6 + uses: actions/checkout@v5 with: submodules: false diff --git a/.github/workflows/windows_build_x64_asan.yml b/.github/workflows/windows_build_x64_asan.yml index 
ac5f08717155f..05fd4acd4de9a 100644 --- a/.github/workflows/windows_build_x64_asan.yml +++ b/.github/workflows/windows_build_x64_asan.yml @@ -19,7 +19,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v6 + uses: actions/checkout@v5 with: submodules: false diff --git a/.github/workflows/windows_cuda.yml b/.github/workflows/windows_cuda.yml index 5d6e9b1da31a2..fd5b65eb039a3 100644 --- a/.github/workflows/windows_cuda.yml +++ b/.github/workflows/windows_cuda.yml @@ -21,7 +21,7 @@ jobs: name: Windows GPU CUDA CI Pipeline runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-vs2022-latest"] steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@v5 with: fetch-depth: 0 submodules: 'none' @@ -152,7 +152,7 @@ jobs: timeout-minutes: 300 runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Win2022-GPU-A10"] steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@v5 with: fetch-depth: 0 submodules: 'none' diff --git a/.github/workflows/windows_dml.yml b/.github/workflows/windows_dml.yml index 0abf6b650f986..e8ee7751348b4 100644 --- a/.github/workflows/windows_dml.yml +++ b/.github/workflows/windows_dml.yml @@ -27,7 +27,7 @@ jobs: ONNXRUNTIME_TEST_GPU_DEVICE_ID: '0' runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Win2022-GPU-A10"] steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@v5 with: fetch-depth: 0 # Fetch all history for all tags and branches submodules: 'none' diff --git a/.github/workflows/windows_openvino.yml b/.github/workflows/windows_openvino.yml index 537ff1fb00071..b608c0879aa45 100644 --- a/.github/workflows/windows_openvino.yml +++ b/.github/workflows/windows_openvino.yml @@ -31,7 +31,7 @@ jobs: steps: - name: Checkout - uses: actions/checkout@v6 + uses: actions/checkout@v5 with: fetch-depth: 0 submodules: none diff --git a/.github/workflows/windows_qnn_x64.yml b/.github/workflows/windows_qnn_x64.yml index f6176164354bb..4f0b50e65df6e 100644 --- a/.github/workflows/windows_qnn_x64.yml +++ b/.github/workflows/windows_qnn_x64.yml @@ -31,7 +31,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v6 + uses: actions/checkout@v5 - name: Setup Python uses: actions/setup-python@v6 diff --git a/.github/workflows/windows_tensorrt.yml b/.github/workflows/windows_tensorrt.yml index 4a564a3b1cb36..229efb01f0018 100644 --- a/.github/workflows/windows_tensorrt.yml +++ b/.github/workflows/windows_tensorrt.yml @@ -21,7 +21,7 @@ jobs: name: Windows GPU TensorRT CI Pipeline runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-vs2022-latest"] steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@v5 with: fetch-depth: 0 submodules: 'none' @@ -157,7 +157,7 @@ jobs: timeout-minutes: 300 runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Win2022-GPU-A10"] steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@v5 with: fetch-depth: 0 submodules: 'none' diff --git a/.github/workflows/windows_webgpu.yml b/.github/workflows/windows_webgpu.yml index f729cda5ea576..899a8b66eac7a 100644 --- a/.github/workflows/windows_webgpu.yml +++ b/.github/workflows/windows_webgpu.yml @@ -34,7 +34,7 @@ jobs: ONNXRUNTIME_TEST_GPU_DEVICE_ID: "0" steps: - name: Checkout - uses: actions/checkout@v6 + uses: actions/checkout@v5 with: fetch-depth: 0 submodules: none @@ -156,7 +156,7 @@ jobs: timeout-minutes: 300 steps: - name: Checkout - uses: actions/checkout@v6 + uses: actions/checkout@v5 with: fetch-depth: 0 submodules: none @@ -209,7 +209,7 @@ jobs: ONNXRUNTIME_TEST_GPU_DEVICE_ID: "0" steps: - name: Checkout - uses: 
actions/checkout@v6 + uses: actions/checkout@v5 with: fetch-depth: 0 submodules: none diff --git a/.github/workflows/windows_x64_debug_build_x64_debug.yml b/.github/workflows/windows_x64_debug_build_x64_debug.yml index 385d03c1a6705..d62c7130e0ebb 100644 --- a/.github/workflows/windows_x64_debug_build_x64_debug.yml +++ b/.github/workflows/windows_x64_debug_build_x64_debug.yml @@ -18,7 +18,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v6 + uses: actions/checkout@v5 with: submodules: false diff --git a/.github/workflows/windows_x64_release_build_x64_release.yml b/.github/workflows/windows_x64_release_build_x64_release.yml index ee045b70b6efa..a2991bb0f1131 100644 --- a/.github/workflows/windows_x64_release_build_x64_release.yml +++ b/.github/workflows/windows_x64_release_build_x64_release.yml @@ -18,7 +18,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v6 + uses: actions/checkout@v5 with: submodules: false diff --git a/.github/workflows/windows_x64_release_ep_generic_interface_build_x64_release_ep_generic_interface.yml b/.github/workflows/windows_x64_release_ep_generic_interface_build_x64_release_ep_generic_interface.yml index 25dfc41e6922c..bb6c5035b0dce 100644 --- a/.github/workflows/windows_x64_release_ep_generic_interface_build_x64_release_ep_generic_interface.yml +++ b/.github/workflows/windows_x64_release_ep_generic_interface_build_x64_release_ep_generic_interface.yml @@ -18,7 +18,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v6 + uses: actions/checkout@v5 with: submodules: false diff --git a/.github/workflows/windows_x64_release_vitisai_build_x64_release.yml b/.github/workflows/windows_x64_release_vitisai_build_x64_release.yml index e738db262f3a2..4378231338673 100644 --- a/.github/workflows/windows_x64_release_vitisai_build_x64_release.yml +++ b/.github/workflows/windows_x64_release_vitisai_build_x64_release.yml @@ -18,7 +18,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v6 + uses: actions/checkout@v5 with: submodules: false diff --git a/.github/workflows/windows_x64_release_xnnpack.yml b/.github/workflows/windows_x64_release_xnnpack.yml index 5672e4043c624..b453cd570ac05 100644 --- a/.github/workflows/windows_x64_release_xnnpack.yml +++ b/.github/workflows/windows_x64_release_xnnpack.yml @@ -18,7 +18,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v6 + uses: actions/checkout@v5 with: submodules: false diff --git a/.github/workflows/windows_x86.yml b/.github/workflows/windows_x86.yml index 381d9dda5cd42..d20778d56f60b 100644 --- a/.github/workflows/windows_x86.yml +++ b/.github/workflows/windows_x86.yml @@ -18,7 +18,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v6 + uses: actions/checkout@v5 with: submodules: false diff --git a/dockerfiles/Dockerfile.rocm b/dockerfiles/Dockerfile.rocm new file mode 100644 index 0000000000000..aca8c3feaff71 --- /dev/null +++ b/dockerfiles/Dockerfile.rocm @@ -0,0 +1,24 @@ +# -------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# -------------------------------------------------------------- +# Dockerfile to run ONNXRuntime with ROCm integration +#-------------------------------------------------------------------------- + +FROM rocm/pytorch:rocm6.2.3_ubuntu22.04_py3.10_pytorch_release_2.3.0 + +ARG ONNXRUNTIME_REPO=https://github.com/Microsoft/onnxruntime +ARG ONNXRUNTIME_BRANCH=main + +WORKDIR /code + +ENV PATH=/code/cmake-3.27.3-linux-x86_64/bin:${PATH} + +# Prepare onnxruntime repository & build onnxruntime +RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXRUNTIME_REPO} onnxruntime &&\ + /bin/sh onnxruntime/dockerfiles/scripts/install_common_deps.sh &&\ + cd onnxruntime &&\ + /bin/sh ./build.sh --allow_running_as_root --config Release --build_wheel --update --build --parallel --cmake_extra_defines\ + ONNXRUNTIME_VERSION=$(cat ./VERSION_NUMBER) --use_rocm --rocm_home=/opt/rocm &&\ + pip install /code/onnxruntime/build/Linux/Release/dist/*.whl &&\ + cd .. diff --git a/dockerfiles/README.md b/dockerfiles/README.md index 88c542b63ccd2..4c69098103edd 100644 --- a/dockerfiles/README.md +++ b/dockerfiles/README.md @@ -1,8 +1,9 @@ # Dockerfiles **Execution Providers** - CPU: [Dockerfile](Dockerfile.source), [Instructions](#cpu) -- CUDA: [Dockerfile](Dockerfile.cuda), [Instructions](#cuda) +- CUDA/cuDNN: [Dockerfile](Dockerfile.cuda), [Instructions](#cuda) - MIGraphX: [Dockerfile](Dockerfile.migraphx), [Instructions](#migraphx) +- ROCm: [Dockerfile](Dockerfile.rocm), [Instructions](#rocm) - OpenVINO: [Dockerfile](Dockerfile.openvino), [Instructions](#openvino) - TensorRT: [Dockerfile](Dockerfile.tensorrt), [Instructions](#tensorrt) - VitisAI: [Dockerfile](Dockerfile.vitisai) @@ -303,3 +304,17 @@ Note: When running the container you built in Docker, please either use 'nvidia- ``` docker run -it --device=/dev/kfd --device=/dev/dri --group-add video onnxruntime-migraphx ``` + + ## ROCm +**Ubuntu 22.04, ROCm6.2.3** + +1. Build the docker image from the Dockerfile in this repository. + ``` + docker build -t onnxruntime-rocm -f Dockerfile.rocm . + ``` + +2. Run the Docker image + + ``` + docker run -it --device=/dev/kfd --device=/dev/dri --group-add video onnxruntime-rocm + ``` diff --git a/dockerfiles/scripts/install_rocm_deps.sh b/dockerfiles/scripts/install_rocm_deps.sh new file mode 100644 index 0000000000000..fd445be87479b --- /dev/null +++ b/dockerfiles/scripts/install_rocm_deps.sh @@ -0,0 +1,84 @@ +#!/bin/bash +prefix=/opt/rocm +DEBIAN_FRONTEND=noninteractive +apt-get update && apt-get install -y --no-install-recommends \ + wget \ + zip \ + ca-certificates \ + build-essential \ + curl \ + libcurl4-openssl-dev \ + libssl-dev \ + python3-dev + +# rocm-cmake +rocm_cmake_version=4.5.2 +wget --quiet https://github.com/RadeonOpenCompute/rocm-cmake/archive/refs/tags/rocm-${rocm_cmake_version}.tar.gz +tar -xzvf rocm-${rocm_cmake_version}.tar.gz +rm rocm-${rocm_cmake_version}.tar.gz +cd rocm-cmake-rocm-${rocm_cmake_version} +mkdir build +cd build +cmake -DCMAKE_INSTALL_PREFIX=$prefix .. +make -j8 +make install +cd ../.. +rm -rf rocm-cmake-rocm-${rocm_cmake_version} + +# rccl +rccl_version=4.5.2 +wget --quiet https://github.com/ROCmSoftwarePlatform/rccl/archive/refs/tags/rocm-${rccl_version}.tar.gz +tar -xzvf rocm-${rccl_version}.tar.gz +rm rocm-${rccl_version}.tar.gz +cd rccl-rocm-${rccl_version} +mkdir build +cd build +CXX=/opt/rocm/bin/hipcc cmake -DCMAKE_INSTALL_PREFIX=$prefix .. +make -j8 +make install +cd ../.. 
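+# Each ROCm dependency built by this script follows the same pattern: download the rocm-4.5.2 release tag tarball, configure with -DCMAKE_INSTALL_PREFIX=$prefix (setting CXX=/opt/rocm/bin/hipcc for the HIP-based projects), build with make -j8, make install, then remove the extracted source tree.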
+rm -rf rccl-rocm-${rccl_version} + +#rocrand +rocrand_version=4.5.2 +wget --quiet https://github.com/ROCmSoftwarePlatform/rocRAND/archive/refs/tags/rocm-${rocrand_version}.tar.gz +tar -xzvf rocm-${rocrand_version}.tar.gz +rm rocm-${rocrand_version}.tar.gz +cd rocRAND-rocm-${rocrand_version} +mkdir build +cd build +CXX=/opt/rocm/bin/hipcc cmake -DCMAKE_INSTALL_PREFIX=$prefix .. +make -j8 +make install +cd ../.. +rm -rf rocRAND-rocm-${rocrand_version} + +#hipcub +hipcub_version=4.5.2 +wget --quiet https://github.com/ROCmSoftwarePlatform/hipCUB/archive/refs/tags/rocm-${hipcub_version}.tar.gz +tar -xzvf rocm-${hipcub_version}.tar.gz +rm rocm-${hipcub_version}.tar.gz +cd hipCUB-rocm-${hipcub_version} +mkdir build +cd build +CXX=/opt/rocm/bin/hipcc cmake -DCMAKE_INSTALL_PREFIX=$prefix .. +make -j8 +make package +make install +cd ../.. +rm -rf hipCUB-rocm-${hipcub_version} + +#rocprim +rocprim_version=4.5.2 +wget --quiet https://github.com/ROCmSoftwarePlatform/rocPRIM/archive/refs/tags/rocm-${rocprim_version}.tar.gz +tar -xzvf rocm-${rocprim_version}.tar.gz +rm rocm-${rocprim_version}.tar.gz +cd rocPRIM-rocm-${rocprim_version} +mkdir build +cd build +CXX=/opt/rocm/bin/hipcc cmake -DCMAKE_INSTALL_PREFIX=$prefix .. +make -j8 +make install +cd ../.. +rm -rf rocPRIM-rocm-${rocprim_version} + diff --git a/js/package-lock.json b/js/package-lock.json index 0fca515b61238..1e9f5cb29fe6c 100644 --- a/js/package-lock.json +++ b/js/package-lock.json @@ -4,7 +4,6 @@ "requires": true, "packages": { "": { - "name": "js", "license": "MIT", "devDependencies": { "@eslint/compat": "^1.4.0", @@ -3231,27 +3230,6 @@ "url": "https://github.com/sponsors/ljharb" } }, - "node_modules/glob": { - "version": "10.5.0", - "resolved": "https://registry.npmjs.org/glob/-/glob-10.5.0.tgz", - "integrity": "sha512-DfXN8DfhJ7NH3Oe7cFmu3NCu1wKbkReJ8TorzSAFbSKrlNaQSKfIzqYqVY8zlbs2NLBbWpRiU52GX2PbaBVNkg==", - "dev": true, - "license": "ISC", - "dependencies": { - "foreground-child": "^3.1.0", - "jackspeak": "^3.1.2", - "minimatch": "^9.0.4", - "minipass": "^7.1.2", - "package-json-from-dist": "^1.0.0", - "path-scurry": "^1.11.1" - }, - "bin": { - "glob": "dist/esm/bin.mjs" - }, - "funding": { - "url": "https://github.com/sponsors/isaacs" - } - }, "node_modules/glob-parent": { "version": "6.0.2", "resolved": "https://registry.npmjs.org/glob-parent/-/glob-parent-6.0.2.tgz", @@ -3264,32 +3242,6 @@ "node": ">=10.13.0" } }, - "node_modules/glob/node_modules/brace-expansion": { - "version": "2.0.2", - "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.2.tgz", - "integrity": "sha512-Jt0vHyM+jmUBqojB7E1NIYadt0vI0Qxjxd2TErW94wDz+E2LAm5vKMXXwg6ZZBTHPuUlDgQHKXvjGBdfcF1ZDQ==", - "dev": true, - "license": "MIT", - "dependencies": { - "balanced-match": "^1.0.0" - } - }, - "node_modules/glob/node_modules/minimatch": { - "version": "9.0.5", - "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.5.tgz", - "integrity": "sha512-G6T0ZX48xgozx7587koeX9Ys2NYy6Gmv//P89sEte9V9whIapMNF4idKxnW2QtCcLiTWlb/wfCabAtAFWhhBow==", - "dev": true, - "license": "ISC", - "dependencies": { - "brace-expansion": "^2.0.1" - }, - "engines": { - "node": ">=16 || 14 >=14.17" - }, - "funding": { - "url": "https://github.com/sponsors/isaacs" - } - }, "node_modules/global-agent": { "version": "3.0.0", "resolved": "https://registry.npmjs.org/global-agent/-/global-agent-3.0.0.tgz", @@ -4359,6 +4311,43 @@ "balanced-match": "^1.0.0" } }, + "node_modules/mocha/node_modules/glob": { + "version": "10.4.5", + "resolved": 
"https://registry.npmjs.org/glob/-/glob-10.4.5.tgz", + "integrity": "sha512-7Bv8RF0k6xjo7d4A/PxYLbUCfb6c+Vpd2/mB2yRDlew7Jb5hEXiCD9ibfO7wpk8i4sevK6DFny9h7EYbM3/sHg==", + "dev": true, + "license": "ISC", + "dependencies": { + "foreground-child": "^3.1.0", + "jackspeak": "^3.1.2", + "minimatch": "^9.0.4", + "minipass": "^7.1.2", + "package-json-from-dist": "^1.0.0", + "path-scurry": "^1.11.1" + }, + "bin": { + "glob": "dist/esm/bin.mjs" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + } + }, + "node_modules/mocha/node_modules/glob/node_modules/minimatch": { + "version": "9.0.5", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.5.tgz", + "integrity": "sha512-G6T0ZX48xgozx7587koeX9Ys2NYy6Gmv//P89sEte9V9whIapMNF4idKxnW2QtCcLiTWlb/wfCabAtAFWhhBow==", + "dev": true, + "license": "ISC", + "dependencies": { + "brace-expansion": "^2.0.1" + }, + "engines": { + "node": ">=16 || 14 >=14.17" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + } + }, "node_modules/mocha/node_modules/minimatch": { "version": "5.1.6", "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-5.1.6.tgz", @@ -8089,40 +8078,6 @@ "get-intrinsic": "^1.2.6" } }, - "glob": { - "version": "10.5.0", - "resolved": "https://registry.npmjs.org/glob/-/glob-10.5.0.tgz", - "integrity": "sha512-DfXN8DfhJ7NH3Oe7cFmu3NCu1wKbkReJ8TorzSAFbSKrlNaQSKfIzqYqVY8zlbs2NLBbWpRiU52GX2PbaBVNkg==", - "dev": true, - "requires": { - "foreground-child": "^3.1.0", - "jackspeak": "^3.1.2", - "minimatch": "^9.0.4", - "minipass": "^7.1.2", - "package-json-from-dist": "^1.0.0", - "path-scurry": "^1.11.1" - }, - "dependencies": { - "brace-expansion": { - "version": "2.0.2", - "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.2.tgz", - "integrity": "sha512-Jt0vHyM+jmUBqojB7E1NIYadt0vI0Qxjxd2TErW94wDz+E2LAm5vKMXXwg6ZZBTHPuUlDgQHKXvjGBdfcF1ZDQ==", - "dev": true, - "requires": { - "balanced-match": "^1.0.0" - } - }, - "minimatch": { - "version": "9.0.5", - "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.5.tgz", - "integrity": "sha512-G6T0ZX48xgozx7587koeX9Ys2NYy6Gmv//P89sEte9V9whIapMNF4idKxnW2QtCcLiTWlb/wfCabAtAFWhhBow==", - "dev": true, - "requires": { - "brace-expansion": "^2.0.1" - } - } - } - }, "glob-parent": { "version": "6.0.2", "resolved": "https://registry.npmjs.org/glob-parent/-/glob-parent-6.0.2.tgz", @@ -8817,6 +8772,31 @@ "balanced-match": "^1.0.0" } }, + "glob": { + "version": "10.4.5", + "resolved": "https://registry.npmjs.org/glob/-/glob-10.4.5.tgz", + "integrity": "sha512-7Bv8RF0k6xjo7d4A/PxYLbUCfb6c+Vpd2/mB2yRDlew7Jb5hEXiCD9ibfO7wpk8i4sevK6DFny9h7EYbM3/sHg==", + "dev": true, + "requires": { + "foreground-child": "^3.1.0", + "jackspeak": "^3.1.2", + "minimatch": "^9.0.4", + "minipass": "^7.1.2", + "package-json-from-dist": "^1.0.0", + "path-scurry": "^1.11.1" + }, + "dependencies": { + "minimatch": { + "version": "9.0.5", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.5.tgz", + "integrity": "sha512-G6T0ZX48xgozx7587koeX9Ys2NYy6Gmv//P89sEte9V9whIapMNF4idKxnW2QtCcLiTWlb/wfCabAtAFWhhBow==", + "dev": true, + "requires": { + "brace-expansion": "^2.0.1" + } + } + } + }, "minimatch": { "version": "5.1.6", "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-5.1.6.tgz", diff --git a/js/react_native/package-lock.json b/js/react_native/package-lock.json index de8d631362db7..e6ed2bdb9e17b 100644 --- a/js/react_native/package-lock.json +++ b/js/react_native/package-lock.json @@ -33,7 +33,6 @@ "version": "1.24.0", "license": "MIT", 
"devDependencies": { - "globby": "^15.0.0", "typedoc": "^0.25.7" } }, @@ -62,15 +61,15 @@ } }, "node_modules/@babel/code-frame": { - "version": "7.27.1", - "resolved": "https://registry.npmjs.org/@babel/code-frame/-/code-frame-7.27.1.tgz", - "integrity": "sha512-cjQ7ZlQ0Mv3b47hABuTevyTuYN4i+loJKGeV9flcCgIK37cCXRh+L1bd3iBHlynerhQ7BhCkn2BPbQUL+rGqFg==", + "version": "7.26.2", + "resolved": "https://registry.npmjs.org/@babel/code-frame/-/code-frame-7.26.2.tgz", + "integrity": "sha512-RJlIHRueQgwWitWgF8OdFYGZX328Ax5BCemNGlqHfplnRT9ESi8JkFlvaVYbS+UubVY6dpv87Fs2u5M29iNFVQ==", "dev": true, "license": "MIT", "dependencies": { - "@babel/helper-validator-identifier": "^7.27.1", + "@babel/helper-validator-identifier": "^7.25.9", "js-tokens": "^4.0.0", - "picocolors": "^1.1.1" + "picocolors": "^1.0.0" }, "engines": { "node": ">=6.9.0" @@ -411,9 +410,9 @@ } }, "node_modules/@babel/helper-string-parser": { - "version": "7.27.1", - "resolved": "https://registry.npmjs.org/@babel/helper-string-parser/-/helper-string-parser-7.27.1.tgz", - "integrity": "sha512-qMlSxKbpRlAridDExk92nSobyDdpPijUq2DW6oDnUqd0iOGxmQjyqhMIihI9+zv4LPyZdRje2cavWPbCbWm3eA==", + "version": "7.25.9", + "resolved": "https://registry.npmjs.org/@babel/helper-string-parser/-/helper-string-parser-7.25.9.tgz", + "integrity": "sha512-4A/SCr/2KLd5jrtOMFzaKjVtAei3+2r/NChoBNoZ3EyP/+GlhoaEGoWOZUmFmoITP7zOJyHIMm+DYRd8o3PvHA==", "dev": true, "license": "MIT", "engines": { @@ -421,9 +420,9 @@ } }, "node_modules/@babel/helper-validator-identifier": { - "version": "7.28.5", - "resolved": "https://registry.npmjs.org/@babel/helper-validator-identifier/-/helper-validator-identifier-7.28.5.tgz", - "integrity": "sha512-qSs4ifwzKJSV39ucNjsvc6WVHs6b7S03sOh2OcHF9UHfVPqWWALUsNUVzhSBiItjRZoLHx7nIarVjqKVusUZ1Q==", + "version": "7.25.9", + "resolved": "https://registry.npmjs.org/@babel/helper-validator-identifier/-/helper-validator-identifier-7.25.9.tgz", + "integrity": "sha512-Ed61U6XJc3CVRfkERJWDz4dJwKe7iLmmJsbOGu9wSloNSFttHV0I8g6UAgb7qnK5ly5bGLPd4oXZlxCdANBOWQ==", "dev": true, "license": "MIT", "engines": { @@ -456,27 +455,27 @@ } }, "node_modules/@babel/helpers": { - "version": "7.28.4", - "resolved": "https://registry.npmjs.org/@babel/helpers/-/helpers-7.28.4.tgz", - "integrity": "sha512-HFN59MmQXGHVyYadKLVumYsA9dBFun/ldYxipEjzA4196jpLZd8UjEEBLkbEkvfYreDqJhZxYAWFPtrfhNpj4w==", + "version": "7.25.6", + "resolved": "https://registry.npmjs.org/@babel/helpers/-/helpers-7.25.6.tgz", + "integrity": "sha512-Xg0tn4HcfTijTwfDwYlvVCl43V6h4KyVVX2aEm4qdO/PC6L2YvzLHFdmxhoeSA3eslcE6+ZVXHgWwopXYLNq4Q==", "dev": true, "license": "MIT", "dependencies": { - "@babel/template": "^7.27.2", - "@babel/types": "^7.28.4" + "@babel/template": "^7.25.0", + "@babel/types": "^7.25.6" }, "engines": { "node": ">=6.9.0" } }, "node_modules/@babel/parser": { - "version": "7.28.5", - "resolved": "https://registry.npmjs.org/@babel/parser/-/parser-7.28.5.tgz", - "integrity": "sha512-KKBU1VGYR7ORr3At5HAtUQ+TV3SzRCXmA/8OdDZiLDBIZxVyzXuztPjfLd3BV1PRAQGCMWWSHYhL0F8d5uHBDQ==", + "version": "7.26.9", + "resolved": "https://registry.npmjs.org/@babel/parser/-/parser-7.26.9.tgz", + "integrity": "sha512-81NWa1njQblgZbQHxWHpxxCzNsa3ZwvFqpUg7P+NNUU6f3UU2jBEg4OlF/J6rl8+PQGh1q6/zWScd001YwcA5A==", "dev": true, "license": "MIT", "dependencies": { - "@babel/types": "^7.28.5" + "@babel/types": "^7.26.9" }, "bin": { "parser": "bin/babel-parser.js" @@ -2115,25 +2114,35 @@ } }, "node_modules/@babel/runtime": { - "version": "7.28.4", - "resolved": "https://registry.npmjs.org/@babel/runtime/-/runtime-7.28.4.tgz", - 
"integrity": "sha512-Q/N6JNWvIvPnLDvjlE1OUBLPQHH6l3CltCEsHIujp45zQUSSh8K+gHnaEX45yAT1nyngnINhvWtzN+Nb9D8RAQ==", + "version": "7.25.6", + "resolved": "https://registry.npmjs.org/@babel/runtime/-/runtime-7.25.6.tgz", + "integrity": "sha512-VBj9MYyDb9tuLq7yzqjgzt6Q+IBQLrGZfdjOekyEirZPHxXWoTSGUTMrpsfi58Up73d13NfYLv8HT9vmznjzhQ==", "dev": true, "license": "MIT", + "dependencies": { + "regenerator-runtime": "^0.14.0" + }, "engines": { "node": ">=6.9.0" } }, + "node_modules/@babel/runtime/node_modules/regenerator-runtime": { + "version": "0.14.1", + "resolved": "https://registry.npmjs.org/regenerator-runtime/-/regenerator-runtime-0.14.1.tgz", + "integrity": "sha512-dYnhHh0nJoMfnkZs6GmmhFknAGRrLznOu5nc9ML+EJxGvrx6H7teuevqVqCuPcPK//3eDrrjQhehXVx9cnkGdw==", + "dev": true, + "license": "MIT" + }, "node_modules/@babel/template": { - "version": "7.27.2", - "resolved": "https://registry.npmjs.org/@babel/template/-/template-7.27.2.tgz", - "integrity": "sha512-LPDZ85aEJyYSd18/DkjNh4/y1ntkE5KwUHWTiqgRxruuZL2F1yuHligVHLvcHY2vMHXttKFpJn6LwfI7cw7ODw==", + "version": "7.26.9", + "resolved": "https://registry.npmjs.org/@babel/template/-/template-7.26.9.tgz", + "integrity": "sha512-qyRplbeIpNZhmzOysF/wFMuP9sctmh2cFzRAZOn1YapxBsE1i9bJIY586R/WBLfLcmcBlM8ROBiQURnnNy+zfA==", "dev": true, "license": "MIT", "dependencies": { - "@babel/code-frame": "^7.27.1", - "@babel/parser": "^7.27.2", - "@babel/types": "^7.27.1" + "@babel/code-frame": "^7.26.2", + "@babel/parser": "^7.26.9", + "@babel/types": "^7.26.9" }, "engines": { "node": ">=6.9.0" @@ -2180,14 +2189,14 @@ "license": "MIT" }, "node_modules/@babel/types": { - "version": "7.28.5", - "resolved": "https://registry.npmjs.org/@babel/types/-/types-7.28.5.tgz", - "integrity": "sha512-qQ5m48eI/MFLQ5PxQj4PFaprjyCTLI37ElWMmNs0K8Lk3dVeOdNpB3ks8jc7yM5CDmVC73eMVk/trk3fgmrUpA==", + "version": "7.26.9", + "resolved": "https://registry.npmjs.org/@babel/types/-/types-7.26.9.tgz", + "integrity": "sha512-Y3IR1cRnOxOCDvMmNiym7XpXQ93iGDDPHx+Zj+NM+rg0fBaShfQLkg+hKPaZCEvg5N/LeCo4+Rj/i3FuJsIQaw==", "dev": true, "license": "MIT", "dependencies": { - "@babel/helper-string-parser": "^7.27.1", - "@babel/helper-validator-identifier": "^7.28.5" + "@babel/helper-string-parser": "^7.25.9", + "@babel/helper-validator-identifier": "^7.25.9" }, "engines": { "node": ">=6.9.0" @@ -3310,9 +3319,9 @@ } }, "node_modules/babel-plugin-module-resolver/node_modules/brace-expansion": { - "version": "2.0.2", - "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.2.tgz", - "integrity": "sha512-Jt0vHyM+jmUBqojB7E1NIYadt0vI0Qxjxd2TErW94wDz+E2LAm5vKMXXwg6ZZBTHPuUlDgQHKXvjGBdfcF1ZDQ==", + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.1.tgz", + "integrity": "sha512-XnAIvQ8eM+kC6aULx6wuQiwVsnzsi9d3WxzV3FpWTGA19F621kwdbsAcFKXgKUHZWsy+mY6iL1sHTxWEFCytDA==", "dev": true, "license": "MIT", "dependencies": { @@ -3468,9 +3477,7 @@ } }, "node_modules/brace-expansion": { - "version": "1.1.12", - "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.12.tgz", - "integrity": "sha512-9T9UjW3r0UW5c1Q7GTwllptXwhvYmEzFhzMfZ9H7FQWt+uZePjZPjBP/W1ZEyZ1twGWom5/56TF4lPcqjnDHcg==", + "version": "1.1.11", "dev": true, "license": "MIT", "dependencies": { @@ -3824,9 +3831,9 @@ } }, "node_modules/compression": { - "version": "1.8.1", - "resolved": "https://registry.npmjs.org/compression/-/compression-1.8.1.tgz", - "integrity": "sha512-9mAqGPHLakhCLeNyxPkK4xVo746zQ/czLH1Ky+vkitMnWfWZps8r0qXuwhwizagCRttsL4lfG4pIOvaWLpAP0w==", + "version": 
"1.8.0", + "resolved": "https://registry.npmjs.org/compression/-/compression-1.8.0.tgz", + "integrity": "sha512-k6WLKfunuqCYD3t6AsuPGvQWaKwuLLh2/xHNcX4qE+vIfDNXpSqnrhwA7O53R7WVQUnt8dVAIW+YHr7xTgOgGA==", "dev": true, "license": "MIT", "dependencies": { @@ -3834,7 +3841,7 @@ "compressible": "~2.0.18", "debug": "2.6.9", "negotiator": "~0.6.4", - "on-headers": "~1.1.0", + "on-headers": "~1.0.2", "safe-buffer": "5.2.1", "vary": "~1.1.2" }, @@ -4814,9 +4821,9 @@ } }, "node_modules/image-size": { - "version": "1.2.1", - "resolved": "https://registry.npmjs.org/image-size/-/image-size-1.2.1.tgz", - "integrity": "sha512-rH+46sQJ2dlwfjfhCyNx5thzrv+dtmBIhPHk0zgRUukHzZ/kRueTJXoYYsclBaKcSMBWuGbOFXtioLpzTb5euw==", + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/image-size/-/image-size-1.2.0.tgz", + "integrity": "sha512-4S8fwbO6w3GeCVN6OPtA9I5IGKkcDMPcKndtUlpJuCwu7JLjtj7JZpwqLuyY2nrmQT3AWsCJLSKPsc2mPBSl3w==", "dev": true, "license": "MIT", "dependencies": { @@ -5243,9 +5250,7 @@ "license": "MIT" }, "node_modules/js-yaml": { - "version": "3.14.2", - "resolved": "https://registry.npmjs.org/js-yaml/-/js-yaml-3.14.2.tgz", - "integrity": "sha512-PMSmkqxr106Xa156c2M265Z+FTrPl+oxd/rgOQy2tijQeK5TxQ43psO1ZCwhVOSdnn+RzkzlRz/eY4BgJBYVpg==", + "version": "3.14.1", "dev": true, "license": "MIT", "dependencies": { @@ -6539,9 +6544,9 @@ } }, "node_modules/on-headers": { - "version": "1.1.0", - "resolved": "https://registry.npmjs.org/on-headers/-/on-headers-1.1.0.tgz", - "integrity": "sha512-737ZY3yNnXy37FHkQxPzt4UZ2UWPWiCZWLvFZ4fu5cueciegX0zGPnrlY6bwRg4FdQOe9YU8MkmJwGhoMybl8A==", + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/on-headers/-/on-headers-1.0.2.tgz", + "integrity": "sha512-pZAE+FJLoyITytdqK0U5s+FIpjN0JP3OzFi/u8Rx+EV5/W+JTWGXG8xFzevE7AjBfDqHv/8vL8qQsIhHnqRkrA==", "dev": true, "license": "MIT", "engines": { @@ -7125,9 +7130,9 @@ "license": "Python-2.0" }, "node_modules/react-native-builder-bob/node_modules/brace-expansion": { - "version": "2.0.2", - "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.2.tgz", - "integrity": "sha512-Jt0vHyM+jmUBqojB7E1NIYadt0vI0Qxjxd2TErW94wDz+E2LAm5vKMXXwg6ZZBTHPuUlDgQHKXvjGBdfcF1ZDQ==", + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.1.tgz", + "integrity": "sha512-XnAIvQ8eM+kC6aULx6wuQiwVsnzsi9d3WxzV3FpWTGA19F621kwdbsAcFKXgKUHZWsy+mY6iL1sHTxWEFCytDA==", "dev": true, "license": "MIT", "dependencies": { @@ -7198,9 +7203,9 @@ } }, "node_modules/react-native-builder-bob/node_modules/js-yaml": { - "version": "4.1.1", - "resolved": "https://registry.npmjs.org/js-yaml/-/js-yaml-4.1.1.tgz", - "integrity": "sha512-qQKT4zQxXl8lLwBtHMWwaTcGfFOZviOJet3Oy/xmGk2gZH677CJM9EvtfdSkgWcATZhj/55JZ0rmy3myCT5lsA==", + "version": "4.1.0", + "resolved": "https://registry.npmjs.org/js-yaml/-/js-yaml-4.1.0.tgz", + "integrity": "sha512-wpxZs9NoxZaJESJGIZTyDEaYpl0FKSA+FB9aJiyemKhMwkxQg63h4T1KJgUGHpTqPDNRcmmYLugrRjJlBtWvRA==", "dev": true, "license": "MIT", "dependencies": { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts index f0f7527f665b9..6a8dffb73fa08 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts @@ -360,7 +360,7 @@ const createInPlaceSoftmaxProgramInfo = ( let local_offset = local_idx * uniforms.elements_per_thread; let offset = (global_idx / ${WG}) * uniforms.total_sequence_length + local_offset; let seq_causal_length = ${seqLens ? 
'u32(past_sequence_length + workgroup_id.y + 1)' : 'total_sequence_length'}; - var thread_max_vector = ${f32Type}(-3.4028234663852886e+38f); + var thread_max_vector = ${f32Type}(-3.402823e+38f); for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) { thread_max_vector = max(${f32Type}(x[offset + i]), thread_max_vector); } @@ -378,7 +378,7 @@ const createInPlaceSoftmaxProgramInfo = ( })()}; workgroupBarrier(); - var max_value = f32(-3.4028234663852886e+38f); + var max_value = f32(-3.402823e+38f); for (var i = 0u; i < ${WG}; i++) { max_value = max(thread_max[i], max_value); } diff --git a/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts b/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts index f6882280e91df..2056416873df5 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts @@ -81,7 +81,7 @@ const createSoftmaxProgramInfo = (context: ComputeContext, attributes: SoftmaxAt // 6.2.4 in wgsl spec const threadMaxDecl = tensorTypeToWsglStorageType(transposedInput.dataType) === 'f32' - ? `var threadMax = ${valueType}(-3.4028234663852886e+38f);` + ? `var threadMax = ${valueType}(-3.402823e+38f);` : `var threadMax = ${valueType}(-65504.0h);`; const getShaderSource = (shaderHelper: ShaderHelper) => ` var rowMaxShared : ${valueType}; diff --git a/onnxruntime/contrib_ops/rocm/bert/attention.cu b/onnxruntime/contrib_ops/rocm/bert/attention.cu new file mode 100644 index 0000000000000..b40fc2bf0eef8 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/attention.cu @@ -0,0 +1,215 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/rocm/bert/attention.h" +#include "contrib_ops/rocm/bert/attention_impl.h" +#include "contrib_ops/rocm/bert/batched_gemm_permute_pipelines.cuh" +#include "contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh" +#include "contrib_ops/rocm/bert/transformer_common.h" +#include "core/providers/rocm/rocm_common.h" +#include "core/providers/rocm/shared_inc/fpgeneric.h" +#include "core/providers/rocm/tunable/gemm.h" + +using namespace onnxruntime::rocm; +using namespace ::onnxruntime::common; +using namespace ONNX_NAMESPACE; + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +constexpr int kPastSequenceLengthInputIndex = 6; +constexpr int kPastInputIndex = 4; +constexpr int kPresentOutputIndex = 1; + +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + Attention, \ + kMSDomain, \ + 1, \ + T, \ + kRocmExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .MayInplace(kPastInputIndex, kPresentOutputIndex) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .InputMemoryType(OrtMemTypeCPUInput, kPastSequenceLengthInputIndex), \ + Attention); + +REGISTER_KERNEL_TYPED(float) +REGISTER_KERNEL_TYPED(MLFloat16) + +template +Attention::Attention(const OpKernelInfo& info) + : RocmKernel(info), AttentionBase(info, true), attn_type_(kAttention) { + using HipT = typename ToHipType::MappedType; + using AttentionTunableOp = GemmSoftmaxGemmPermuteTunableOp; + tunable_op_ = std::make_shared(); +} + +template +Status Attention::ComputeInternal(OpKernelContext* context) const { + const Tensor* input = context->Input(0); + const Tensor* weights = context->Input(1); + const Tensor* bias = context->Input(2); + const Tensor* mask_index = context->Input(3); + const Tensor* past = context->Input(4); + const Tensor* attention_bias = context->Input(5); + const Tensor* past_seq_len = 
context->Input(kPastSequenceLengthInputIndex); + + auto& device_prop = GetDeviceProp(); + RocmAttentionParameters attn; + ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(), + weights->Shape(), + bias->Shape(), + mask_index, + past, + attention_bias, + &attn, + device_prop.maxThreadsPerBlock, + past_seq_len)); + ORT_ENFORCE(attn.sequence_length == attn.kv_sequence_length); // self attention + ORT_ENFORCE(attn.qkv_format == Q_K_V_BNSH); // non-packed, permuted + + TensorShapeVector output_shape(3); + output_shape[0] = static_cast(attn.batch_size); + output_shape[1] = static_cast(attn.sequence_length); + output_shape[2] = static_cast(attn.v_hidden_size); + Tensor* output = context->Output(0, output_shape); + + std::vector present_dims{ + 2, attn.batch_size, attn.num_heads, + past_present_share_buffer_ ? attn.max_sequence_length : attn.total_sequence_length, + attn.head_size}; + TensorShape present_shape(present_dims); + Tensor* present = context->Output(kPresentOutputIndex, present_shape); + + auto stream = Stream(context); + hipblasHandle_t hipblas = GetHipblasHandle(context); + + using HipT = typename ToHipType::MappedType; + using QkvProjectGeneric = GemmPermuteGenericPipeline; + using AttentionGeneric = GemmSoftmaxGemmPermuteGenericPipeline; + using AttentionTunableOp = GemmSoftmaxGemmPermuteTunableOp; + + ORT_RETURN_IF_ERROR(ClassifyAttentionMode(attn_type_, &attn, /*qkv=*/{}, /*past=*/{past}, /*present=*/{present})); + ORT_ENFORCE(attn.mode == QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE || + attn.mode == QFMT_KFMT_VFMT_NONE_NONE_2BNTH_NONE || + attn.mode == QFMT_KFMT_VFMT_NONE_NONE_2BNMH_NONE || + attn.mode == QFMT_KFMT_VFMT_2BNPH_NONE_2BNTH_NONE || + attn.mode == QFMT_KFMT_VFMT_2BNMH_NONE_2BNMH_NONE); + + size_t qkv_project_output_bytes = QkvProjectGeneric::GetOutputNumBytes(&attn); + size_t shared_workspace_bytes = std::max(QkvProjectGeneric::GetWorkspaceNumBytes(&attn), + AttentionGeneric::GetWorkspaceNumBytes(&attn)); + if (GetTuningContext()->IsTunableOpEnabled()) { + shared_workspace_bytes = std::max(shared_workspace_bytes, AttentionTunableOp::GetWorkspaceNumBytes(&attn)); + } + + auto qkv_project_output = GetScratchBuffer(qkv_project_output_bytes, context->GetComputeStream()); + auto workspace = GetScratchBuffer(shared_workspace_bytes, context->GetComputeStream()); + + GemmPermuteParams gemm_permute_params; + { + auto& params = gemm_permute_params; + params.tuning_ctx = GetTuningContext(); + params.stream = context->GetComputeStream(); + params.handle = hipblas; + params.attention = &attn; + params.device_prop = &device_prop; + + params.input_buffer = reinterpret_cast(input->DataRaw()); + params.weight_buffer = reinterpret_cast(weights->DataRaw()); + params.bias_buffer = reinterpret_cast(bias->DataRaw()); + params.out_buffer = reinterpret_cast(qkv_project_output.get()); + params.ones = GetConstOnes(attn.batch_size * attn.sequence_length, stream); + params.workspace_buffer = reinterpret_cast(workspace.get()); + } + + ORT_RETURN_IF_ERROR(QkvProjectGeneric::Run(&gemm_permute_params)); + auto [q_buffer, k_buffer, v_buffer] = QkvProjectGeneric::UnspliceOutputQKV(&gemm_permute_params); + + // NOTE: GemmPermute always output 3BNSH, k_buffer and v_buffer can be treated as 2BNSH + if (nullptr != present) { + Strides dst_strides; // the output buffer is present Tensor, the buffer is the same + + int4 add_shape{2 * attn.batch_size, attn.num_heads, attn.sequence_length, attn.head_size}; + HipT* add_dest = nullptr; // destination of concatenated data to present + const HipT* const add_src = k_buffer; // 
source of concatenated data to present + const auto add_src_strides = Strides::BNSHMemory( + 2 * attn.batch_size, attn.num_heads, attn.sequence_length, attn.head_size); + + if (attn.mode == QFMT_KFMT_VFMT_NONE_NONE_2BNTH_NONE) { + dst_strides = Strides::BNSHMemory(2 * attn.batch_size, attn.num_heads, attn.total_sequence_length, attn.head_size); + add_dest = reinterpret_cast<HipT*>(present->MutableDataRaw()) /* + dst_strides.OffsetAt(0, 0, 0, 0)*/; + } else if (attn.mode == QFMT_KFMT_VFMT_2BNPH_NONE_2BNTH_NONE) { + dst_strides = Strides::BNSHMemory(2 * attn.batch_size, attn.num_heads, attn.total_sequence_length, attn.head_size); + add_dest = reinterpret_cast<HipT*>(present->MutableDataRaw()) + dst_strides.OffsetAt(0, 0, attn.past_sequence_length, 0); + + // We only need to copy past to present in this case. All other cases will build the present incrementally + const int4 past_shape = {2 * attn.batch_size, attn.num_heads, attn.past_sequence_length, attn.head_size}; + HipT* const past_dest = reinterpret_cast<HipT*>(present->MutableDataRaw()); + const HipT* const past_src = reinterpret_cast<const HipT*>(past->DataRaw()); + const Strides past_src_strides = Strides::BNSHMemory( + 2 * attn.batch_size, attn.num_heads, attn.past_sequence_length, attn.head_size); + + ORT_RETURN_IF_ERROR(LaunchStridedCopy(stream, past_src, past_shape, past_src_strides.ForBNSHCoord(), + past_dest, dst_strides.ForBNSHCoord(), device_prop.maxThreadsPerBlock)); + } else if (attn.mode == QFMT_KFMT_VFMT_NONE_NONE_2BNMH_NONE) { + dst_strides = Strides::BNSHMemory(2 * attn.batch_size, attn.num_heads, attn.max_sequence_length, attn.head_size); + add_dest = reinterpret_cast<HipT*>(present->MutableDataRaw()) /* + dst_strides.OffsetAt(0, 0, 0, 0)*/; + } else if (attn.mode == QFMT_KFMT_VFMT_2BNMH_NONE_2BNMH_NONE) { + dst_strides = Strides::BNSHMemory(2 * attn.batch_size, attn.num_heads, attn.max_sequence_length, attn.head_size); + add_dest = reinterpret_cast<HipT*>(present->MutableDataRaw()) + dst_strides.OffsetAt(0, 0, attn.past_sequence_length, 0); + } + + ORT_RETURN_IF_ERROR(LaunchStridedCopy(stream, add_src, add_shape, add_src_strides.ForBNSHCoord(), + add_dest, dst_strides.ForBNSHCoord(), device_prop.maxThreadsPerBlock)); + + // update pointers to present_k and present_v.
TODO: switch to ConvertToOffsetedBufferViews + k_buffer = reinterpret_cast<HipT*>(present->MutableDataRaw()); + v_buffer = reinterpret_cast<HipT*>(present->MutableDataRaw()) + dst_strides.OffsetAt(attn.batch_size, 0, 0, 0); + } + + // For testing, environment variable ORT_TRANSFORMER_OPTIONS=1 could enable persistent softmax + const TransformerOptions* options = TransformerOptions::GetInstance(); + bool use_persistent_softmax = options->IsPrecisionMode() && !options->DisablePersistentSoftmax(); + + GemmSoftmaxGemmPermuteParams<HipT> gemm_softmax_gemm_permute_params; + { + auto& params = gemm_softmax_gemm_permute_params; + params.tuning_ctx = GetTuningContext(); + params.stream = context->GetComputeStream(); + params.handle = hipblas; + params.attention = &attn; + params.device_prop = &device_prop; + // FIXME: the params.scale seems to be different from AttentionParameters::scale; + params.scale = 1.0f / sqrt(static_cast<float>(attn.head_size)); + // TODO: switch to ConvertToOffsetedBufferViews + params.q_buffer = q_buffer; + params.k_buffer = k_buffer; + params.v_buffer = v_buffer; + params.out_buffer = reinterpret_cast<HipT*>(output->MutableDataRaw()); + + if (attention_bias != nullptr) { + params.bias_buffer = reinterpret_cast<const HipT*>(attention_bias->DataRaw()); + } + + if (mask_index != nullptr) { + params.mask_index_buffer = mask_index->Data<int>(); + params.mask_index_dims = mask_index->Shape().AsShapeVector(); + } + + params.workspace_buffer = reinterpret_cast<HipT*>(workspace.get()); + } + + if (this->GetTuningContext()->IsTunableOpEnabled() && + !use_persistent_softmax) { + return (*std::static_pointer_cast<AttentionTunableOp>(tunable_op_))(&gemm_softmax_gemm_permute_params); + } else { + return AttentionGeneric::Run(&gemm_softmax_gemm_permute_params, use_persistent_softmax); + } +} + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/attention.h b/onnxruntime/contrib_ops/rocm/bert/attention.h new file mode 100644 index 0000000000000..7204fd660a516 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/attention.h @@ -0,0 +1,33 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/rocm/rocm_kernel.h" +#include "contrib_ops/cpu/bert/attention_base.h" +#include "contrib_ops/rocm/bert/attention_impl.h" + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +using namespace onnxruntime::rocm; + +template <typename T> +class Attention final : public RocmKernel, public AttentionBase { + public: + Attention(const OpKernelInfo& info); + Status ComputeInternal(OpKernelContext* context) const override; + + public: + AttentionType attn_type_; + + // type-erased GemmSoftmaxGemmPermuteTunableOp, the reason for this is: + // 1. We don't want to include the cuh file where GemmSoftmaxGemmPermuteTunableOp is defined. + // 2. We don't want to construct the object repeatedly (which is expensive) during Compute.
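+ // (At dispatch time, attention.cu recovers the concrete type via std::static_pointer_cast<AttentionTunableOp>(tunable_op_) before invoking it.)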
+ std::shared_ptr tunable_op_; +}; + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu b/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu new file mode 100644 index 0000000000000..270a8e51daf88 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu @@ -0,0 +1,435 @@ +/* + The implementation of this file is based on qkvToContext plugin in TensorRT demo: + https://github.com/NVIDIA/TensorRT/tree/release/5.1/demo/BERT/ + +Copyright 2019 NVIDIA Corporation + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Modifications: scaling is moved from masked softmax to the gemm before that. +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include "core/providers/rocm/cu_inc/common.cuh" +#include "core/providers/rocm/rocm_common.h" +#include "core/providers/rocm/shared_inc/fpgeneric.h" +#include "core/providers/rocm/tunable/gemm.h" +#include "core/providers/rocm/tunable/rocm_tunable.h" +#include "contrib_ops/cpu/bert/attention_base.h" +#include "contrib_ops/rocm/bert/attention_impl.h" +#include "contrib_ops/rocm/bert/attention_softmax.h" +#include "contrib_ops/rocm/bert/decoder_attention_impl.h" + +using namespace onnxruntime::rocm; + +namespace blas = onnxruntime::rocm::tunable::blas; + +#define CHECK_ROCM(expr) HIP_RETURN_IF_ERROR(expr) + +using namespace onnxruntime::rocm; +using namespace ::onnxruntime::common; +using namespace ONNX_NAMESPACE; + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +static size_t AlignTo(size_t a, size_t b) { + return CeilDiv(a, b) * b; +} + +size_t GetAttentionScratchSize(size_t element_size, + int batch_size, + int num_heads, + int sequence_length, + int total_sequence_length) { + const size_t bytes = element_size * batch_size * num_heads * sequence_length * total_sequence_length; + + const size_t alignment = 256; + const size_t bytesAligned = AlignTo(bytes, alignment); + return bytesAligned; +} + +size_t GetAttentionWorkspaceSize( + size_t element_size, + int batch_size, + int num_heads, + int head_size, + int sequence_length, + int total_sequence_length) { + size_t qkv_size = element_size * 3 * batch_size * sequence_length * num_heads * head_size; + return qkv_size + 2 * GetAttentionScratchSize(element_size, batch_size, num_heads, + sequence_length, total_sequence_length); +} + +inline int3 Get2DMaskStrides(int total_sequence_length) { + // stride == 0 indicate broadcasting + return {total_sequence_length, 0, 1}; +} + +Status ClassifyAttentionMode( + AttentionType attn_type, + RocmAttentionParameters* attn, + const std::vector& qkv, + const std::vector& past, + const std::vector& present) { + size_t num_qkv = std::count_if(qkv.cbegin(), qkv.cend(), [](auto it) { return it != nullptr; }); + size_t num_past = std::count_if(past.cbegin(), past.cend(), [](auto it) { return it != nullptr; }); + size_t num_present = std::count_if(present.cbegin(), present.cend(), [](auto it) { return it != nullptr; }); + 
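+ // A hedged reading of the mode names used below (an inference from the enum names, not stated in this patch): the seven fields appear to tag the layouts of Q, K, V, past-K, past-V, present-K and present-V in that order, where B = batch, N = num_heads, S/L = query/key-value sequence length, P = past, T = total, M = max sequence length and H = head size; "FMT" marks the BNSH output of the QKV projection GEMM, and a leading 2 marks K and V packed into a single tensor. For example, QFMT_KFMT_VFMT_2BNPH_NONE_2BNTH_NONE would read as projected Q/K/V with a packed past of shape (2, B, N, P, H) and a packed present of shape (2, B, N, T, H).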
+ auto hint = MakeString(num_qkv, " qkv inputs, ", num_past, " past inputs and ", num_present, " present inputs"); + LOGS_DEFAULT(VERBOSE) << hint; + + if (attn_type == kAttention) { + ORT_ENFORCE(num_qkv == 0); + if (num_past == 0 && num_present == 0) { + attn->mode = QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE; + return Status::OK(); + } else if (num_past == 0 && num_present == 1) { + if (attn->past_present_share_buffer == false) { + attn->mode = QFMT_KFMT_VFMT_NONE_NONE_2BNTH_NONE; + return Status::OK(); + } else { + attn->mode = QFMT_KFMT_VFMT_NONE_NONE_2BNMH_NONE; + return Status::OK(); + } + } else if (num_past == 1 && num_present == 1) { + if (attn->past_present_share_buffer == false) { + attn->mode = QFMT_KFMT_VFMT_2BNPH_NONE_2BNTH_NONE; + return Status::OK(); + } else { + attn->mode = QFMT_KFMT_VFMT_2BNMH_NONE_2BNMH_NONE; + return Status::OK(); + } + } + } else if (attn_type == kMultiHeadAttention || attn_type == kDecoderMaskedMultiHeadAttention) { + if (num_qkv == 3 && num_past == 0 && num_present == 0) { + if (attn->qkv_format == Q_K_V_BSNH) { + attn->mode = BSNH_BLNH_BLNH_NONE_NONE_NONE_NONE; + return Status::OK(); + } else if (attn->qkv_format == Q_K_V_BSNH_BNSH_BNSH) { + attn->mode = BSNH_BNLH_BNLH_NONE_NONE_NONE_NONE; + return Status::OK(); + } + } else if (num_qkv == 3 && num_past == 0 && num_present == 2) { + if (attn->past_present_share_buffer == false) { + if (attn->qkv_format == Q_K_V_BSNH) { + attn->mode = BSNH_BLNH_BLNH_NONE_NONE_BNTH_BNTH; + return Status::OK(); + } else if (attn->qkv_format == Q_K_V_BSNH_BNSH_BNSH) { + attn->mode = BSNH_BNLH_BNLH_NONE_NONE_BNTH_BNTH; + return Status::OK(); + } + } else { + if (attn->qkv_format == Q_K_V_BSNH) { + attn->mode = BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH; + return Status::OK(); + } else if (attn->qkv_format == Q_K_V_BSNH_BNSH_BNSH) { + attn->mode = BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH; + return Status::OK(); + } + } + } else if (num_qkv == 3 && num_past == 2 && num_present == 2) { + if (attn->past_present_share_buffer == false) { + if (attn->qkv_format == Q_K_V_BSNH) { + attn->mode = BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH; + return Status::OK(); + } else if (attn->qkv_format == Q_K_V_BSNH_BNSH_BNSH) { + attn->mode = BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH; + return Status::OK(); + } + } else { + if (attn->qkv_format == Q_K_V_BSNH) { + attn->mode = BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH; + return Status::OK(); + } else if (attn->qkv_format == Q_K_V_BSNH_BNSH_BNSH) { + attn->mode = BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH; + return Status::OK(); + } + } + } else if (num_qkv == 1 && num_past == 0 && num_present == 0) { + if (attn->qkv_format == QKV_BSN3H) { + attn->mode = BLN3H_NONE_NONE_NONE_NONE_NONE_NONE; + return Status::OK(); + } + } else if (num_qkv == 2 && num_past == 0 && num_present == 0) { + if (attn->qkv_format == Q_KV_BSNH_BSN2H) { + attn->mode = BSNH_BLN2H_NONE_NONE_NONE_NONE_NONE; + return Status::OK(); + } + } + } + return ORT_MAKE_STATUS( + ONNXRUNTIME, INVALID_ARGUMENT, + "Unsupported AttentionMode for ", attn_type, ". Got qkv format ", attn->qkv_format, + ". 
Got ", hint); +} + +template +Status DecoderQkvToContext( + const hipDeviceProp_t& prop, + RocmTuningContext* tuning_ctx, + Stream* ort_stream, + hipblasHandle_t& hipblas, + const size_t element_size, + const int batch_size, + const int sequence_length, + const int kv_sequence_length, + const int num_heads, + const int head_size, + const bool static_kv, + const bool use_past, + const bool has_layer_state, + const bool has_key_padding_mask, + const float mask_filter_value, + const T* gemm_query_buffer, + const T* gemm_kv_buffer, + const bool* key_padding_mask, + const T* key_cache, + const T* value_cache, + T* qkv_buffer, + T* workspace_buffer, + T* output, + T* new_key_cache, + T* new_value_cache) { + const int max_threads_per_block = prop.maxThreadsPerBlock; + const int BN = batch_size * num_heads; + const int BHN = BN * head_size; + const int BNS = BN * sequence_length; + const int k_buffer_offset = sequence_length * BHN; + const int v_buffer_offset = (sequence_length + kv_sequence_length) * BHN; + + T* temp_qkv_buffer = workspace_buffer; + auto stream = static_cast(ort_stream->GetHandle()); + + const T* q = qkv_buffer; + // transpose q and copy them to qkv_buffer + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, head_size, + num_heads, max_threads_per_block, true, gemm_query_buffer, qkv_buffer)); + + const T* k = qkv_buffer + k_buffer_offset; + const T* v = qkv_buffer + v_buffer_offset; + if (!has_layer_state || !use_past) { + if (!static_kv) { + // transpose kv and copy them to qkv_buffer + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 2, sequence_length, batch_size, head_size, num_heads, + max_threads_per_block, true, gemm_kv_buffer, qkv_buffer + k_buffer_offset)); + } else { + // transpose kv and copy them to qkv_buffer + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 2, kv_sequence_length, batch_size, head_size, num_heads, + max_threads_per_block, true, gemm_kv_buffer, qkv_buffer + k_buffer_offset)); + } + } else { + if (!static_kv) { + // transpose kv and copy them to temp_buffer + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 2, sequence_length, batch_size, head_size, num_heads, + max_threads_per_block, true, gemm_kv_buffer, temp_qkv_buffer)); + // concat cache-k with k and copy to qkv_buffer + if (nullptr != key_cache) { + ORT_RETURN_IF_ERROR(LaunchConcatTensorToTensor(stream, kv_sequence_length, sequence_length, + batch_size, head_size, num_heads, + max_threads_per_block, 1, key_cache, + temp_qkv_buffer, qkv_buffer + k_buffer_offset)); + } + // concat cache-v with v and copy to qkv_buffer + if (nullptr != value_cache) { + ORT_RETURN_IF_ERROR(LaunchConcatTensorToTensor(stream, kv_sequence_length, sequence_length, + batch_size, head_size, num_heads, + max_threads_per_block, 1, value_cache, + temp_qkv_buffer + k_buffer_offset, + qkv_buffer + v_buffer_offset)); + } + } + } + + if (has_layer_state) { + if (use_past && static_kv) { + CHECK_ROCM(hipMemcpyAsync(new_key_cache, key_cache, + kv_sequence_length * BHN * sizeof(T), hipMemcpyDeviceToDevice, stream)); + CHECK_ROCM(hipMemcpyAsync(new_value_cache, value_cache, + kv_sequence_length * BHN * sizeof(T), hipMemcpyDeviceToDevice, stream)); + } else { + CHECK_ROCM(hipMemcpyAsync(new_key_cache, k, + kv_sequence_length * BHN * sizeof(T), hipMemcpyDeviceToDevice, stream)); + CHECK_ROCM(hipMemcpyAsync(new_value_cache, v, + kv_sequence_length * BHN * sizeof(T), hipMemcpyDeviceToDevice, stream)); + } + } + + // scratch1: BxNxSxS* buffer + // scratch2: BxNxSxS* buffer + // scratch3: BxNxSxH buffer + T* scratch1 = 
temp_qkv_buffer + 3 * BHN * sequence_length; + T* scratch2 = scratch1 + BNS * kv_sequence_length; + T* scratch3 = scratch2 + BNS * kv_sequence_length; + + // compute Q*K' (as K'*Q), scaled by 1/sqrt(H) and store in scratch1: BxNxSxS* + // Q: BxNxSxH, K (present_k): BxNxS*xH, Q*K': BxNxSxS* + const float rsqrt_head_size = 1.f / sqrt(static_cast(head_size)); + const int temp_matrix_size = sequence_length * kv_sequence_length; + + const int strideA = kv_sequence_length * head_size; + const int strideB = sequence_length * head_size; + if (use_past && static_kv) { + ORT_RETURN_IF_ERROR(blas::column_major::StridedBatchedGemm( + tuning_ctx, ort_stream, hipblas, + blas::BlasOp::Trans, blas::BlasOp::NonTrans, + kv_sequence_length, sequence_length, head_size, + /*alpha=*/rsqrt_head_size, + key_cache, head_size, strideA, + q, head_size, strideB, + /*beta=*/0.0f, + scratch1, kv_sequence_length, temp_matrix_size, + BN)); + } else { + ORT_RETURN_IF_ERROR(blas::column_major::StridedBatchedGemm( + tuning_ctx, ort_stream, hipblas, + blas::BlasOp::Trans, blas::BlasOp::NonTrans, + kv_sequence_length, sequence_length, head_size, + /*alpha=*/rsqrt_head_size, + k, head_size, strideA, + q, head_size, strideB, + /*beta=*/0.0f, + scratch1, kv_sequence_length, temp_matrix_size, + BN)); + } + + if (has_key_padding_mask) { + int3 strides = Get2DMaskStrides(kv_sequence_length); + ORT_RETURN_IF_ERROR(ComputeSoftmaxWithRawMask( + ort_stream, kv_sequence_length, sequence_length, batch_size, num_heads, + strides, nullptr, key_padding_mask, nullptr, scratch1, scratch2, + false, 1.0f, false, nullptr, mask_filter_value)); + } else { + ORT_RETURN_IF_ERROR(ComputeSoftmax(stream, kv_sequence_length, sequence_length, batch_size, + num_heads, nullptr, scratch1, scratch2, false)); + } + + // compute P*V (as V*P), and store in scratch3: BxNxSxH + if (use_past && static_kv) { + ORT_RETURN_IF_ERROR(blas::column_major::StridedBatchedGemm( + tuning_ctx, ort_stream, hipblas, + blas::BlasOp::NonTrans, blas::BlasOp::NonTrans, + head_size, sequence_length, kv_sequence_length, + /*alpha=*/1.0f, + value_cache, head_size, strideA, + scratch2, kv_sequence_length, temp_matrix_size, + /*beta=*/0.0f, + scratch3, head_size, strideB, + BN)); + } else { + ORT_RETURN_IF_ERROR(blas::column_major::StridedBatchedGemm( + tuning_ctx, ort_stream, hipblas, + blas::BlasOp::NonTrans, blas::BlasOp::NonTrans, + head_size, sequence_length, kv_sequence_length, + /*alpha=*/1.0f, + v, head_size, strideA, + scratch2, kv_sequence_length, temp_matrix_size, + /*beta=*/0.0f, + scratch3, head_size, strideB, + BN)); + } + + // scratch3 is BxNxSxH, transpose to output SxBxNxH + return LaunchTransCtx(stream, sequence_length, batch_size, head_size, + num_heads, max_threads_per_block, true, scratch3, output); +} + +Status LaunchDecoderAttentionKernel( + const hipDeviceProp_t& prop, + RocmTuningContext* tuning_ctx, + Stream* stream, + hipblasHandle_t& hipblas, + const size_t element_size, + const int batch_size, + const int sequence_length, + const int kv_sequence_length, + const int num_heads, + const int head_size, + const bool static_kv, + const bool use_past, + const bool has_layer_state, + const bool has_key_padding_mask, + const float mask_filter_value, + const void* gemm_query_buffer, + const void* gemm_kv_buffer, + const bool* key_padding_mask, + const void* key_cache, + const void* value_cache, + void* qkv_buffer, + void* workspace_buffer, + void* output, + void* new_key_cache, + void* new_value_cache) { + if (element_size == 2) { + return DecoderQkvToContext( + prop, 
+ tuning_ctx, + stream, + hipblas, + element_size, + batch_size, + sequence_length, + kv_sequence_length, + num_heads, + head_size, + static_kv, + use_past, + has_layer_state, + has_key_padding_mask, + mask_filter_value, + reinterpret_cast(gemm_query_buffer), + reinterpret_cast(gemm_kv_buffer), + key_padding_mask, + reinterpret_cast(key_cache), + reinterpret_cast(value_cache), + reinterpret_cast(qkv_buffer), + reinterpret_cast(workspace_buffer), + reinterpret_cast(output), + reinterpret_cast(new_key_cache), + reinterpret_cast(new_value_cache)); + } else { + return DecoderQkvToContext( + prop, + tuning_ctx, + stream, + hipblas, + element_size, + batch_size, + sequence_length, + kv_sequence_length, + num_heads, + head_size, + static_kv, + use_past, + has_layer_state, + has_key_padding_mask, + mask_filter_value, + reinterpret_cast(gemm_query_buffer), + reinterpret_cast(gemm_kv_buffer), + key_padding_mask, + reinterpret_cast(key_cache), + reinterpret_cast(value_cache), + reinterpret_cast(qkv_buffer), + reinterpret_cast(workspace_buffer), + reinterpret_cast(output), + reinterpret_cast(new_key_cache), + reinterpret_cast(new_value_cache)); + } +} + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/attention_impl.h b/onnxruntime/contrib_ops/rocm/bert/attention_impl.h new file mode 100644 index 0000000000000..07d875e90fa4b --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/attention_impl.h @@ -0,0 +1,180 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include "contrib_ops/cpu/bert/attention_common.h" +#include "contrib_ops/cpu/bert/attention_parameters.h" +#include "core/providers/rocm/shared_inc/rocm_utils.h" +#include "core/providers/rocm/tunable/rocm_tunable.h" + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +typedef struct __align__(32) { + long long int x, y, z, w; +} LongLong4; + +size_t GetAttentionScratchSize( + size_t element_size, + int batch_size, + int num_heads, + int sequence_length, + int all_sequence_length); + +size_t GetAttentionWorkspaceSize( + size_t element_size, + int batch_size, + int num_heads, + int head_size, + int sequence_length, + int past_sequence_length); + +Status LaunchTransCtx(hipStream_t stream, + const int sequence_length, const int batch_size, const int head_size, const int num_heads, + const int max_threads_per_block, const bool reversed_bs, const float* input, float* output); + +Status LaunchTransCtx(hipStream_t stream, + const int sequence_length, const int batch_size, const int head_size, const int num_heads, + const int max_threads_per_block, const bool reversed_bs, const half* input, half* output); + +Status LaunchTransQkv(hipStream_t stream, const int matrix_num, + const int sequence_length, const int batch_size, const int head_size, const int num_heads, + const int max_threads_per_block, const bool reversed_bs, const float* input, float* output, + int total_matrix_count = -1); + +Status LaunchTransQkv(hipStream_t stream, const int matrix_num, + const int sequence_length, const int batch_size, const int head_size, const int num_heads, + const int max_threads_per_block, const bool reversed_bs, const half* input, half* output, + int total_matrix_count = -1); + +Status LaunchConcatTensorToTensor(hipStream_t stream, + const int all_sequence_length, + const int sequence_length, + const int batch_size, + const int head_size, + const int num_heads, + const int 
max_threads_per_block, + const int matrix_num, + const float* tensor_in, + const float* tensor_add, + float* tensor_out); + +Status LaunchConcatTensorToTensor(hipStream_t stream, + const int all_sequence_length, + const int sequence_length, + const int batch_size, + const int head_size, + const int num_heads, + const int max_threads_per_block, + const int matrix_num, + const half* tensor_in, + const half* tensor_add, + half* tensor_out); + +inline hipblasStatus_t _compat_hipblas_gemm_strided_batched_ex(hipblasHandle_t handle, + hipblasOperation_t transa, + hipblasOperation_t transb, + int m, + int n, + int k, + const void* alpha, + const void* A, + hipDataType a_type, + int lda, + hipblasStride stride_A, + const void* b, + hipDataType b_type, + int ldb, + hipblasStride stride_b, + const void* beta, + void* c, + hipDataType c_type, + int ldc, + hipblasStride stride_c, + int batch_count, + hipblasComputeType_t compute_type, + hipblasGemmAlgo_t algo) { + return hipblasGemmStridedBatchedEx(handle, + transa, + transb, + m, // m + n, // n + k, // k + alpha, // alpha + A, // A + a_type, // A type + lda, // lda + stride_A, // strideA + b, // B + b_type, // B type + ldb, // ldb + stride_b, // strideB + beta, // beta + c, // C + c_type, // C type + ldc, // ldc + stride_c, // strideC + batch_count, // batch count + compute_type, + algo); +} + +// Compatible for CublasMathModeSetter +class CompatHipblasMathModeSetter { + public: + CompatHipblasMathModeSetter(const hipDeviceProp_t&, + hipblasHandle_t, + int) { + } +}; + +enum AttentionMode { + // Q,K,V,PastK,PastV,PresentK,PresentV + QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE, + QFMT_KFMT_VFMT_NONE_NONE_2BNTH_NONE, + QFMT_KFMT_VFMT_NONE_NONE_2BNMH_NONE, + QFMT_KFMT_VFMT_2BNPH_NONE_2BNTH_NONE, + QFMT_KFMT_VFMT_2BNMH_NONE_2BNMH_NONE, + BSNH_BLNH_BLNH_NONE_NONE_NONE_NONE, + BSNH_BNLH_BNLH_NONE_NONE_NONE_NONE, + BSNH_BLNH_BLNH_NONE_NONE_BNTH_BNTH, + BSNH_BNLH_BNLH_NONE_NONE_BNTH_BNTH, + BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH, + BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH, + BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH, + BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH, + BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH, + BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH, + BLN3H_NONE_NONE_NONE_NONE_NONE_NONE, + BSNH_BLN2H_NONE_NONE_NONE_NONE_NONE, +}; + +struct RocmAttentionParameters : AttentionParameters { + AttentionMode mode; +}; + +Status ClassifyAttentionMode(AttentionType type, + RocmAttentionParameters* attn, + const std::vector& qkv, + const std::vector& past, + const std::vector& present); + +template +Status LaunchStridedCopy( + hipStream_t stream, + const T* in, int4 in_shape, LongLong4 in_strides, const int* in_seqlens_offset, // coord (b,n,s,h) + T* out, LongLong4 out_strides, const int* out_seqlens_offset, // coord (b,n,s,h) + int max_threads_per_block); + +template +Status LaunchStridedCopy(hipStream_t stream, + const T* in, int4 in_shape, LongLong4 in_strides, // coord (b,n,s,h) + T* out, LongLong4 out_strides, // coord (b,n,s,h) + int max_threads_per_block); +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/attention_softmax.h b/onnxruntime/contrib_ops/rocm/bert/attention_softmax.h new file mode 100644 index 0000000000000..9f2faa228cf79 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/attention_softmax.h @@ -0,0 +1,465 @@ +#include "hip/hip_runtime.h" +/* + The implementation of this file is based on qkvToContext plugin in TensorRT demo: + https://github.com/NVIDIA/TensorRT/tree/release/5.1/demo/BERT/ + +Copyright 2019 NVIDIA 
Corporation + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#pragma once + +#include +#include +#include +#include "core/providers/rocm/cu_inc/common.cuh" +#include "core/providers/rocm/rocm_common.h" +#include "core/providers/rocm/math/softmax.h" + +#define ROCMRT_INF_F __int_as_float(0x7f800000) + +using namespace onnxruntime::rocm; +using namespace hipcub; + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +template +__device__ inline void Softmax(const int all_sequence_length, + const int valid_end, + const int valid_start, + const T* attn_bias, + const T* input, + T* output) { + using BlockReduce = hipcub::BlockReduce; + __shared__ typename BlockReduce::TempStorage tmp_storage; + + __shared__ float sum_reverse_block; + __shared__ float max_block; + + float thread_data_max(-ROCMRT_INF_F); + + // e^x is represented as infinity if x is large enough, like 100.f. + // Infinity divided by Infinity is a NAN. Thus, softmax gets a NAN if one or more item are large enough. + // a math transform as below is leveraged to get a stable softmax: + // e^xi/(e^x1 + ...e^xn) = e^(xi - max) / (e^(x1 - max) + ... + e^(xn - max)) + const int offset = (blockIdx.y * gridDim.x + blockIdx.x) * all_sequence_length; + for (int i = threadIdx.x; i < valid_end; i += TPB) { + if (i >= valid_start) { + const int index = offset + i; + float input_at_idx = attn_bias == nullptr + ? static_cast(input[index]) + : static_cast(input[index] + attn_bias[index]); + if (thread_data_max < input_at_idx) { + thread_data_max = input_at_idx; + } + } + } + + const auto max = BlockReduce(tmp_storage).Reduce(thread_data_max, hipcub::Max()); + + // Store max value + if (threadIdx.x == 0) { + max_block = max; + } + __syncthreads(); + + float thread_data_sum(0.f); + for (int i = threadIdx.x; i < valid_end; i += TPB) { + if (i >= valid_start) { + const int index = offset + i; + float val = attn_bias == nullptr ? input[index] : input[index] + attn_bias[index]; + thread_data_sum += expf(val - max_block); + } + } + + const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_sum, hipcub::Sum()); + if (threadIdx.x == 0) { + sum_reverse_block = 1.f / sum; + } + __syncthreads(); + + for (int i = threadIdx.x; i < all_sequence_length; i += TPB) { + const int index = offset + i; + float input_at_idx = attn_bias == nullptr + ? static_cast(input[index]) + : static_cast(input[index] + attn_bias[index]); + const float val = (i >= valid_start && i < valid_end) ? 
expf(input_at_idx - max_block) * sum_reverse_block : 0.f; + output[index] = T(val); + } +} + +template +__device__ inline void SoftmaxSmall(const int all_sequence_length, + const int sequence_length, + const int valid_end, + const int valid_start, + const T* attn_bias, + const T* input, + T* output, + bool causal) { + using BlockReduce = hipcub::BlockReduce; + __shared__ typename BlockReduce::TempStorage tmp_storage; + + __shared__ float sum_reverse_block; + __shared__ float max_block; + + // Input dimension is BxNxSxS*; blockIdx.y is batch index b; gridDim.x=N*S; blockIdx.x is index within N*S; + const int offset = (blockIdx.y * gridDim.x + blockIdx.x) * all_sequence_length; + const int index = offset + threadIdx.x; + + bool is_valid = false; // whether it has attention mask == 1. + + // Update end position for causal. + int end = valid_end; + if (causal) { + const int end_causal = all_sequence_length - sequence_length + (blockIdx.x % sequence_length) + 1; + if (end_causal < end) { + end = end_causal; + } + } + + is_valid = (threadIdx.x >= valid_start && threadIdx.x < end); + + // e^x is represented as infinity if x is large enough, like 100.f. + // Infinity divided by Infinity is a NAN. Thus, softmax gets a NAN if one or more item are large enough. + // a math transform as below is leveraged to get a stable softmax: + // e^xi/(e^x1 + ...e^xn) = e^(xi - max) / (e^(x1 - max) + ... + e^(xn - max)) + float input_data = attn_bias == nullptr + ? static_cast(input[index]) + : static_cast(input[index] + attn_bias[index]); + float thread_data_max = is_valid ? input_data : float(-ROCMRT_INF_F); + const auto max = BlockReduce(tmp_storage).Reduce(thread_data_max, hipcub::Max(), end); + + // Store max value + if (threadIdx.x == 0) { + max_block = max; + } + __syncthreads(); + + float thread_data_exp(0.f); + if (is_valid) { + thread_data_exp = expf(input_data - max_block); + } + + const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_exp, hipcub::Sum(), end); + + // Store value of 1.0/sum. + if (threadIdx.x == 0) { + sum_reverse_block = (1.f) / sum; + } + __syncthreads(); + + // threadIdx.x might be larger than all_sequence_length due to alignment to 32x. + if (threadIdx.x < all_sequence_length) { + output[index] = is_valid ? T(thread_data_exp * sum_reverse_block) : T(0.f); + } +} + +// Note about the attention_mask_strides and attention_mask/key_padding_mask +// attention_mask accepts 2D, 3D or 4D tensor, but it will be viewed as 3D tensor uniformally and it will be indexed +// as [batch_index, sequence_index, token_index]. 
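The comment above describes viewing any 2D/3D/4D attention mask as a 3D tensor indexed by [batch_index, sequence_index, token_index], where a stride of 0 broadcasts an axis (Get2DMaskStrides returns {total_sequence_length, 0, 1} for a [B, T] mask). A small host-side sketch of that addressing, with illustrative shapes:

#include <cstdio>

struct Strides3 { int batch, sequence, token; };

// offset = b * stride_b + s * stride_s + t * stride_t; a zero stride reuses the same
// mask row for every value of that coordinate.
int MaskOffset(Strides3 st, int b, int s, int t) {
  return st.batch * b + st.sequence * s + st.token * t;
}

int main() {
  const int T = 8, S = 4;
  Strides3 mask_2d{T, 0, 1};      // [B, T]: broadcast over query positions
  Strides3 mask_3d{S * T, T, 1};  // [B, S, T]: one row per query position
  std::printf("2D offset: %d, 3D offset: %d\n",
              MaskOffset(mask_2d, 1, 3, 5), MaskOffset(mask_3d, 1, 3, 5));
  return 0;
}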
+template +__global__ void SoftmaxWithRawMaskSmallKernel( + const int all_sequence_length, + const int sequence_length, + const int3 attention_mask_strides, + const int* attention_mask, // 2D, 3D or 4D attention mask + const bool* key_padding_mask, + const T* attn_bias, + const T* input, + T* output, + const bool causal, + const float rsqrt_head_size, + const bool skip_softmax, + const float mask_filter_value) { + using BlockReduce = hipcub::BlockReduce; + __shared__ typename BlockReduce::TempStorage tmp_storage; + + __shared__ float sum_reverse_block; + __shared__ float max_block; + + // Input dimension is BxNxSxS*; blockIdx.y is batch index b; gridDim.x=N*S; blockIdx.x is index within N*S; + int index = (blockIdx.y * gridDim.x + blockIdx.x) * all_sequence_length + threadIdx.x; + + // Mask all thread_data values to negative infinity to allow BlockReduce Max operation over all thread_data + // members with all invalid members set to a value that does not impact the final result. This is necessary + // to avoid the performance impact from using the valid_items interface. + float thread_data = -ROCMRT_INF_F; + if (threadIdx.x < all_sequence_length) { + thread_data = float(input[index]) * rsqrt_head_size; + + const int sequence_index = blockIdx.x % sequence_length; + if (causal) { + int from_index = all_sequence_length - sequence_length + sequence_index; // offset in all sequence length. + if (threadIdx.x > from_index) { + thread_data = -ROCMRT_INF_F; + } + } + + const int batch_index = blockIdx.y; + int mask_offset = attention_mask_strides.x * batch_index + + attention_mask_strides.y * sequence_index + + attention_mask_strides.z * threadIdx.x; + + if (nullptr == key_padding_mask) { + const int& mask = attention_mask[mask_offset]; + if (mask == 0) + thread_data += mask_filter_value; + } else { + const bool mask = key_padding_mask[mask_offset]; + if (mask) { + thread_data = -ROCMRT_INF_F; + } + } + + if (attn_bias != nullptr) { + thread_data += float(attn_bias[index]); + } + } + + if (skip_softmax) { + if (threadIdx.x < all_sequence_length) { + output[index] = T(thread_data); + } + return; + } + + const float max = BlockReduce(tmp_storage).Reduce(thread_data, hipcub::Max()); + + // Store max value + if (threadIdx.x == 0) { + max_block = max; + } + __syncthreads(); + + // Mask all thread_data_exp values to zero to allow BlockReduce Sum operation over all thread_data_exp + // members with all invalid members set to a value that does not impact the final result. This is necessary + // to avoid the performance impact from using the valid_items interface. + float thread_data_exp = threadIdx.x < all_sequence_length ? 
expf(thread_data - max_block) : 0.0f; + const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_exp, hipcub::Sum()); + + // Store value of 1.0/sum + if (threadIdx.x == 0) { + sum_reverse_block = (1.f) / sum; + } + __syncthreads(); + + if (threadIdx.x < all_sequence_length) { + output[index] = T(thread_data_exp * sum_reverse_block); + } +} + +template +__global__ void SoftmaxKernelSmall(const int all_sequence_length, const int sequence_length, + const T* attn_bias, const T* input, T* output, bool causal) { + SoftmaxSmall(all_sequence_length, sequence_length, all_sequence_length, 0, + attn_bias, input, output, causal); +} + +template +__global__ void SoftmaxKernel(const int all_sequence_length, const T* attn_bias, const T* input, T* output) { + Softmax(all_sequence_length, all_sequence_length, 0, attn_bias, input, output); +} + +template +Status ComputeSoftmax( + hipStream_t stream, + const int all_sequence_length, const int sequence_length, const int batch_size, const int num_heads, + const T* attn_bias, const T* input, T* output, bool causal) { + const dim3 grid(sequence_length * num_heads, batch_size, 1); + if (all_sequence_length <= 32) { + const int blockSize = 32; + SoftmaxKernelSmall<<>>( + all_sequence_length, sequence_length, attn_bias, input, output, causal); + } else if (all_sequence_length <= 64) { + const int blockSize = 64; + SoftmaxKernelSmall<<>>( + all_sequence_length, sequence_length, attn_bias, input, output, causal); + } else if (all_sequence_length <= 128) { + const int blockSize = 128; + SoftmaxKernelSmall<<>>( + all_sequence_length, sequence_length, attn_bias, input, output, causal); + } else if (all_sequence_length <= 256) { + const int blockSize = 256; + SoftmaxKernelSmall<<>>( + all_sequence_length, sequence_length, attn_bias, input, output, causal); + } else if (all_sequence_length <= 512) { + const int blockSize = 512; + SoftmaxKernelSmall<<>>( + all_sequence_length, sequence_length, attn_bias, input, output, causal); + } else if (all_sequence_length <= 1024) { + const int blockSize = 1024; + SoftmaxKernelSmall<<>>( + all_sequence_length, sequence_length, attn_bias, input, output, causal); + } else if (!causal) { + const int blockSize = 1024; + SoftmaxKernel<<>>( + all_sequence_length, attn_bias, input, output); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Attention ROCM operator does not support total sequence length > 1024."); + } + + return HIP_CALL(hipPeekAtLastError()); +} + +template +__global__ void MaskedSoftmaxKernelSmall(const int all_sequence_length, const int sequence_length, + const int* mask_end, const int* mask_start, + const T* attn_bias, const T* input, T* output, + bool causal) { + __shared__ int start_position; + __shared__ int end_position; + + if (threadIdx.x == 0) { + const int batch = blockIdx.y; + start_position = mask_start != nullptr ? max(0, mask_start[batch]) : 0; + end_position = min(all_sequence_length, mask_end[batch]); + + // Attend to no word has same effect as attend to all words. This is added to get parity with CPU result. 
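ComputeSoftmax above dispatches to a kernel whose thread-block size is the smallest supported power of two that covers the total sequence length, and only the non-causal generic kernel handles lengths above 1024. A compact host-side sketch of that selection logic (illustrative only, not part of the patch):

#include <initializer_list>

// Returns the block size a ComputeSoftmax-style dispatch would pick, or -1 when only
// the generic (non-causal, 1024-thread, column-looping) kernel applies.
inline int PickSoftmaxBlockSize(int all_sequence_length) {
  for (int block_size : {32, 64, 128, 256, 512, 1024}) {
    if (all_sequence_length <= block_size) return block_size;
  }
  return -1;
}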
+ if (start_position >= end_position) { + start_position = 0; + end_position = all_sequence_length; + } + } + __syncthreads(); + + SoftmaxSmall(all_sequence_length, sequence_length, end_position, start_position, + attn_bias, input, output, causal); +} + +template +__global__ void MaskedSoftmaxKernel(const int all_sequence_length, const int* mask_end, const int* mask_start, + const T* attn_bias, const T* input, T* output) { + __shared__ int start_position; + __shared__ int end_position; + + if (threadIdx.x == 0) { + const int batch = blockIdx.y; + start_position = mask_start != nullptr ? max(0, mask_start[batch]) : 0; + end_position = min(all_sequence_length, mask_end[batch]); + + // Attend to no word has same effect as attend to all words. This is added to get parity with CPU result. + if (start_position >= end_position) { + start_position = 0; + end_position = all_sequence_length; + } + } + __syncthreads(); + + Softmax(all_sequence_length, end_position, start_position, attn_bias, input, output); +} + +template +Status ComputeSoftmaxWithMask1D( + hipStream_t stream, + const int all_sequence_length, const int sequence_length, const int batch_size, const int num_heads, + const int* mask_index, const int* mask_start, + const T* attn_bias, const T* input, T* output, const bool causal) { + const dim3 grid(sequence_length * num_heads, batch_size, 1); + +#define DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(block_size) \ + MaskedSoftmaxKernelSmall<<>>( \ + all_sequence_length, sequence_length, mask_index, mask_start, \ + attn_bias, input, output, causal); + + if (all_sequence_length <= 32) { + DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(32); + } else if (all_sequence_length <= 64) { + DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(64); + } else if (all_sequence_length <= 128) { + DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(128); + } else if (all_sequence_length <= 256) { + DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(256); + } else if (all_sequence_length <= 512) { + DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(512); + } else if (all_sequence_length <= 1024) { + DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(1024); + } else if (!causal) { + const int blockSize = 1024; + MaskedSoftmaxKernel<<>>( + all_sequence_length, mask_index, mask_start, + attn_bias, input, output); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Attention ROCM operator does not support total sequence length > 1024."); + } + +#undef DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE + + return HIP_CALL(hipPeekAtLastError()); +} + +template +Status ComputeSoftmaxWithRawMask(Stream* ort_stream, + const int all_sequence_length, + const int sequence_length, + const int batch_size, + const int num_heads, + const int3 attention_mask_strides, + const int* attention_mask, + const bool* key_padding_mask, + const T* attn_bias, + const T* input, + T* output, + const bool causal, + const float rsqrt_head_size, + const bool use_persistent_softmax, + T* persistent_softmax_workspace, + const float mask_filter_value) { + const dim3 grid(sequence_length * num_heads, batch_size, 1); + + T* out = use_persistent_softmax ? 
persistent_softmax_workspace : output; + auto stream = static_cast(ort_stream->GetHandle()); + +#define DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(block_size) \ + SoftmaxWithRawMaskSmallKernel<<>>( \ + all_sequence_length, sequence_length, attention_mask_strides, \ + attention_mask, key_padding_mask, attn_bias, input, out, \ + causal, rsqrt_head_size, \ + use_persistent_softmax, mask_filter_value); + + if (all_sequence_length <= 32) { + DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(32); + } else if (all_sequence_length <= 64) { + DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(64); + } else if (all_sequence_length <= 128) { + DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(128); + } else if (all_sequence_length <= 256) { + DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(256); + } else if (all_sequence_length <= 512) { + DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(512); + } else if (all_sequence_length <= 1024) { + DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(1024); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Attention ROCM operator does not support total sequence length > 1024."); + } + +#undef DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE + + if (use_persistent_softmax) { + return dispatch_warpwise_softmax_forward(ort_stream, + output, + persistent_softmax_workspace, + all_sequence_length, + all_sequence_length, + batch_size * num_heads * sequence_length); + } + + return HIP_CALL(hipPeekAtLastError()); +} + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_permute_pipelines.cuh b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_permute_pipelines.cuh new file mode 100644 index 0000000000000..213940f132963 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_permute_pipelines.cuh @@ -0,0 +1,125 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
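For reference, a plain-C++ sketch of the per-element raw-mask rule that SoftmaxWithRawMaskSmallKernel applies before its reduction (buffer indexing and causal handling are omitted; mask_filter_value is typically a large negative number):

#include <limits>

// score is the scaled QK' value for one (batch, head, query, key) position.
// An int attention_mask entry of 0 adds mask_filter_value (soft mask-out); a true
// key_padding_mask entry forces the score to -inf (hard mask-out).
inline float ApplyRawMask(float score,
                          const int* attention_mask,
                          const bool* key_padding_mask,
                          int mask_offset,
                          float mask_filter_value) {
  if (key_padding_mask == nullptr) {
    if (attention_mask[mask_offset] == 0) score += mask_filter_value;
  } else if (key_padding_mask[mask_offset]) {
    score = -std::numeric_limits<float>::infinity();
  }
  return score;
}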
+ +#pragma once + +#include "core/providers/rocm/rocm_common.h" +#include "core/providers/rocm/rocm_kernel.h" +#include "core/providers/rocm/tunable/gemm.h" +#include "core/providers/rocm/tunable/rocm_tunable.h" +#include "contrib_ops/cpu/bert/attention_common.h" +#include "contrib_ops/cpu/bert/attention_parameters.h" + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +namespace blas = onnxruntime::rocm::tunable::blas; + +namespace { +std::tuple GetQkvProjectGemmMNKBatch(const AttentionParameters* attention) { + int m = attention->sequence_length; + int n = (attention->hidden_size + attention->hidden_size + attention->v_hidden_size); // q + k + v + int k = attention->input_hidden_size; + int batch = attention->batch_size; + return {m, n, k, batch}; +} +} // namespace + +template +struct GemmPermuteParams : onnxruntime::rocm::tunable::OpParams { + std::string Signature() const override { + auto [m, n, k, batch] = GetQkvProjectGemmMNKBatch(attention); + return MakeString("M", m, "_N", n, "_K", k, "_B", batch); + } + + hipblasHandle_t handle; + const AttentionParameters* attention; + const hipDeviceProp_t* device_prop; + + const T* input_buffer; + const T* weight_buffer; + const T* bias_buffer; + T* out_buffer; + + int3 bias_strides; + + const T* ones; // used for broadcasting bias if the underlying algorithm does not support strides + T* workspace_buffer; +}; + +template +struct GemmPermuteGenericPipeline { + inline static size_t GetOutputNumBytes(const AttentionParameters* attn) { + auto [m, n, _, batch] = GetQkvProjectGemmMNKBatch(attn); + return sizeof(T) * m * n * batch; + } + + inline static size_t GetWorkspaceNumBytes(const AttentionParameters* attn) { + return GetOutputNumBytes(attn); + } + + inline static std::tuple GetGemmMNK(const GemmPermuteParams* params) { + auto [m, n, k, batch] = GetQkvProjectGemmMNKBatch(params->attention); + return {batch * m, n, k}; + } + + inline static std::tuple UnspliceOutputQKV(const GemmPermuteParams* params) { + auto* attn = params->attention; + int64_t batch = attn->batch_size * attn->num_heads; + int64_t num_elems_per_batch = attn->sequence_length * attn->head_size; + int64_t num_elems = batch * num_elems_per_batch; + auto q = params->out_buffer + 0 * num_elems; + auto k = params->out_buffer + 1 * num_elems; + auto v = params->out_buffer + 2 * num_elems; + return {q, k, v}; + } + + inline static Status BroadcastBias(const GemmPermuteParams* params) { + auto [m, n, k] = GetGemmMNK(params); + // Bias shape is (N), broadcast using B(M, N) = ones(M, 1) x bias(1, N). + // TODO: use custom kernel of expand to improve the performance. + return blas::row_major::Gemm( + params->TuningContext(), params->Stream(), params->handle, + blas::BlasOp::NonTrans, blas::BlasOp::NonTrans, + m, n, 1, + /*alpha=*/1.0f, + params->ones, 1, + params->bias_buffer, n, + /*beta=*/0.0f, + params->workspace_buffer, n); + } + + inline static Status Gemm(const GemmPermuteParams* params) { + auto [m, n, k] = GetGemmMNK(params); + // result(M, N) = input x weights + bias. 
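A naive CPU reference (illustrative, row-major float, not the tuned ROCm path) of the two steps used by GemmPermuteGenericPipeline above: the (N)-shaped bias is first expanded to (M, N) via ones(M, 1) x bias(1, N), and the projection GEMM then accumulates on top of it with beta = 1.

#include <cstddef>
#include <vector>

void QkvProjectReference(int M, int N, int K,
                         const std::vector<float>& input,   // M x K
                         const std::vector<float>& weight,  // K x N
                         const std::vector<float>& bias,    // N
                         std::vector<float>& out) {         // M x N
  out.assign(static_cast<size_t>(M) * N, 0.f);
  // BroadcastBias: out(m, n) = bias(n), i.e. ones(M, 1) * bias(1, N).
  for (int m = 0; m < M; ++m)
    for (int n = 0; n < N; ++n) out[static_cast<size_t>(m) * N + n] = bias[n];
  // Gemm with beta = 1: out = input * weight + out.
  for (int m = 0; m < M; ++m)
    for (int k = 0; k < K; ++k)
      for (int n = 0; n < N; ++n)
        out[static_cast<size_t>(m) * N + n] +=
            input[static_cast<size_t>(m) * K + k] * weight[static_cast<size_t>(k) * N + n];
}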
+ return blas::row_major::Gemm( + params->TuningContext(), params->Stream(), params->handle, + blas::BlasOp::NonTrans, blas::BlasOp::NonTrans, + m, n, k, + /*alpha=*/1.0f, + params->input_buffer, k, + params->weight_buffer, n, + /*beta=*/1.0f, + params->workspace_buffer, n); + } + + inline static Status Permute0213(const GemmPermuteParams* params) { + auto* attn = params->attention; + // input should be BxSx3xNxH => gemm_buffer: 3xBxNxSxH + return LaunchTransQkv( + params->StreamHandle(), 3, attn->sequence_length, attn->batch_size, attn->head_size, attn->num_heads, + params->device_prop->maxThreadsPerBlock, false, params->workspace_buffer, params->out_buffer); + } + + static Status Run(const GemmPermuteParams* params) { + ORT_RETURN_IF_ERROR(BroadcastBias(params)); + ORT_RETURN_IF_ERROR(Gemm(params)); + ORT_RETURN_IF_ERROR(Permute0213(params)); + return Status::OK(); + } +}; + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl.cuh b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl.cuh new file mode 100644 index 0000000000000..be8508670e4b1 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl.cuh @@ -0,0 +1,177 @@ +// SPDX-License-Identifier: MIT +// Modifications Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#ifdef USE_COMPOSABLE_KERNEL +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" +#include "ck/utility/data_type.hpp" + +namespace onnxruntime { +namespace contrib { +namespace rocm { +namespace internal { + +using F16 = ck::half_t; +using F32 = float; + +template +using S = ck::Sequence; + +using MaskingSpecialization = ck::tensor_operation::device::MaskingSpecialization; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute; // the interface +using ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle; // the implementation + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmPadded = ck::tensor_operation::device::GemmSpecialization::MNKOPadding; + +static constexpr auto TensorDefault = ck::tensor_operation::device::TensorSpecialization::Default; + +template +using device_batched_gemm_softmax_gemm_permute_instances = + std::tuple< + // clang-format off + // #############################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| AData| B0Data| B1Data| CData| Acc0BiasData| Acc1BiasData| AccData| CShuffle| A| B0| Acc0| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| 
B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| MaskingSpec| D0s Bias| + // #############################################| | | | | | Type| Type| Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | SrcScalar| + // #############################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | PerVector| + // #############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec, 1>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, +#if ROCM_VERSION >= 50500 + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 256, 32, 64, 32, 8, 8, 2, 32, 32, 1, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 
2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, +#endif + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 256, 32, 128, 32, 8, 8, 2, 32, 32, 1, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 32, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, 
GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 32, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 64, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 64, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8, MaskingSpec>, + // Padded fallback kernel + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec, 1>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec> + // clang-format on + >; + +struct PreSoftmaxAttentionScoreOp { + PreSoftmaxAttentionScoreOp(float scale) : scale_(scale) {} + + // non-biased, non-masked + __host__ __device__ void operator()(float& y, const float& x) const { + y = scale_ * x; + } + + // biased or converted masked + __host__ __device__ void operator()(float& y, const float& x, const F16& bias) const { + y = scale_ * x + ck::type_convert(bias); + } + + // biased and converted masked + __host__ __device__ void operator()(float& y, const float& x, const F16& bias, const F16& 
converted_mask) const { + y = scale_ * x + ck::type_convert(bias) + ck::type_convert(converted_mask); + } + + float scale_; +}; + +// Use this function to gat implementation +template +std::vector, + PassThrough, PassThrough, D0Op, PassThrough, PassThrough, + MaskingSpec>>> +GetDeviceBatchedGemmSoftmaxGemmPermuteInstances() { + return {}; +} + +// implemented in impl_{fp16,bf16}[_biased][_masked].cu +// fp16, non-biased, non-masked +template <> +std::vector, ck::Tuple<>, + PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, + MaskingSpecialization::MaskDisabled>>> +GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< + F16, ck::Tuple<>, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskDisabled>(); + +// fp16, biased, non-masked +template <> +std::vector, ck::Tuple<>, + PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, + MaskingSpecialization::MaskDisabled>>> +GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< + F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskDisabled>(); + +// fp16, biased, fp16 masked, basically, two bias +template <> +std::vector, ck::Tuple<>, + PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, + MaskingSpecialization::MaskDisabled>>> +GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< + F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskDisabled>(); + +template <> +std::vector, ck::Tuple<>, + PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, + MaskingSpecialization::MaskOutUpperTriangle>>> +GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< + F16, ck::Tuple<>, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>(); + +// fp16, biased, non-masked +template <> +std::vector, ck::Tuple<>, + PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, + MaskingSpecialization::MaskOutUpperTriangle>>> +GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< + F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>(); + +// fp16, biased, fp16 masked, basically, two bias +template <> +std::vector, ck::Tuple<>, + PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, + MaskingSpecialization::MaskOutUpperTriangle>>> +GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< + F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>(); + +} // namespace internal +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime +#endif // USE_COMPOSABLE_KERNEL diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16.cu b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16.cu new file mode 100644 index 0000000000000..2e32a6594d164 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16.cu @@ -0,0 +1,60 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
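The element-wise functor above is what folds scaling, an optional bias, and an optional mask converted to a bias into the epilogue of the first GEMM; a plain-float restatement for reference (the ck::half_t conversions are elided):

// y = scale * x, optionally plus a bias and/or a "converted mask" (0 for kept
// positions, a large negative value for masked ones), applied per score element.
struct PreSoftmaxAttentionScoreRef {
  explicit PreSoftmaxAttentionScoreRef(float scale) : scale_(scale) {}

  float operator()(float x) const { return scale_ * x; }
  float operator()(float x, float bias) const { return scale_ * x + bias; }
  float operator()(float x, float bias, float converted_mask) const {
    return scale_ * x + bias + converted_mask;
  }

  float scale_;
};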
+ +#ifdef USE_COMPOSABLE_KERNEL +#include "contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl.cuh" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp" + +namespace onnxruntime { +namespace contrib { +namespace rocm { +namespace internal { + +using NonBiasedNonmasked = DeviceBatchedGemmSoftmaxGemmPermute< + 2, 1, 1, 1, 1, + F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, + PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, + MaskingSpecialization::MaskDisabled>; + +template <> +std::vector> +GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< + F16, ck::Tuple<>, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskDisabled>() { + std::vector> instances; + ck::tensor_operation::device::instance::add_device_operation_instances( + instances, + device_batched_gemm_softmax_gemm_permute_instances< + 2, 1, 1, 1, 1, + F16, ck::Tuple<>, F32, PreSoftmaxAttentionScoreOp, + MaskingSpecialization::MaskDisabled>{}); + + return instances; +} + +using NonBiasedNonmaskedCausal = DeviceBatchedGemmSoftmaxGemmPermute< + 2, 1, 1, 1, 1, + F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, + PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, + MaskingSpecialization::MaskOutUpperTriangle>; + +template <> +std::vector> +GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< + F16, ck::Tuple<>, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>() { + std::vector> instances; + ck::tensor_operation::device::instance::add_device_operation_instances( + instances, + device_batched_gemm_softmax_gemm_permute_instances< + 2, 1, 1, 1, 1, + F16, ck::Tuple<>, F32, PreSoftmaxAttentionScoreOp, + MaskingSpecialization::MaskOutUpperTriangle>{}); + + return instances; +} + +} // namespace internal +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime +#endif // USE_COMPOSABLE_KERNEL diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased.cu b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased.cu new file mode 100644 index 0000000000000..91da8d9e1f9a8 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased.cu @@ -0,0 +1,60 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
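The impl_fp16*.cu files follow one registration pattern: each translation unit explicitly specializes GetDeviceBatchedGemmSoftmaxGemmPermuteInstances for a single bias/mask combination, so the heavy composable_kernel instantiations are split across object files. A stripped-down illustration with hypothetical names (RegistryBase, GetInstances, VariantA are not part of this patch):

#include <memory>
#include <vector>

// registry.h (sketch): the primary template is only declared here.
struct RegistryBase {
  virtual ~RegistryBase() = default;
};

template <typename Variant>
std::vector<std::unique_ptr<RegistryBase>> GetInstances();

// registry_variant_a.cc (sketch): one heavy variant per translation unit keeps
// compile time and object size bounded, mirroring the per-file specializations above.
struct VariantA {};
struct VariantAOp final : RegistryBase {};

template <>
std::vector<std::unique_ptr<RegistryBase>> GetInstances<VariantA>() {
  std::vector<std::unique_ptr<RegistryBase>> instances;
  instances.emplace_back(std::make_unique<VariantAOp>());
  return instances;
}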
+ +#ifdef USE_COMPOSABLE_KERNEL +#include "contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl.cuh" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp" + +namespace onnxruntime { +namespace contrib { +namespace rocm { +namespace internal { + +using BiasedNonmasked = DeviceBatchedGemmSoftmaxGemmPermute< + 2, 1, 1, 1, 1, + F16, F16, F16, F16, ck::Tuple, ck::Tuple<>, + PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, + MaskingSpecialization::MaskDisabled>; + +template <> +std::vector> +GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< + F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskDisabled>() { + std::vector> instances; + ck::tensor_operation::device::instance::add_device_operation_instances( + instances, + device_batched_gemm_softmax_gemm_permute_instances< + 2, 1, 1, 1, 1, + F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, + MaskingSpecialization::MaskDisabled>{}); + + return instances; +} + +using BiasedNonmaskedCausal = DeviceBatchedGemmSoftmaxGemmPermute< + 2, 1, 1, 1, 1, + F16, F16, F16, F16, ck::Tuple, ck::Tuple<>, + PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, + MaskingSpecialization::MaskOutUpperTriangle>; + +template <> +std::vector> +GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< + F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>() { + std::vector> instances; + ck::tensor_operation::device::instance::add_device_operation_instances( + instances, + device_batched_gemm_softmax_gemm_permute_instances< + 2, 1, 1, 1, 1, + F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, + MaskingSpecialization::MaskOutUpperTriangle>{}); + + return instances; +} + +} // namespace internal +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime +#endif // USE_COMPOSABLE_KERNEL diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased_biased.cu b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased_biased.cu new file mode 100644 index 0000000000000..b08123be18977 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased_biased.cu @@ -0,0 +1,60 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#ifdef USE_COMPOSABLE_KERNEL +#include "contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl.cuh" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp" + +namespace onnxruntime { +namespace contrib { +namespace rocm { +namespace internal { + +using BiasedNonmasked = DeviceBatchedGemmSoftmaxGemmPermute< + 2, 1, 1, 1, 1, + F16, F16, F16, F16, ck::Tuple, ck::Tuple<>, + PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, + MaskingSpecialization::MaskDisabled>; + +template <> +std::vector> +GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< + F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskDisabled>() { + std::vector> instances; + ck::tensor_operation::device::instance::add_device_operation_instances( + instances, + device_batched_gemm_softmax_gemm_permute_instances< + 2, 1, 1, 1, 1, + F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, + MaskingSpecialization::MaskDisabled>{}); + + return instances; +} + +using BiasedNonmaskedCausal = DeviceBatchedGemmSoftmaxGemmPermute< + 2, 1, 1, 1, 1, + F16, F16, F16, F16, ck::Tuple, ck::Tuple<>, + PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, + MaskingSpecialization::MaskOutUpperTriangle>; + +template <> +std::vector> +GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< + F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>() { + std::vector> instances; + ck::tensor_operation::device::instance::add_device_operation_instances( + instances, + device_batched_gemm_softmax_gemm_permute_instances< + 2, 1, 1, 1, 1, + F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, + MaskingSpecialization::MaskOutUpperTriangle>{}); + + return instances; +} + +} // namespace internal +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime +#endif // USE_COMPOSABLE_KERNEL diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh new file mode 100644 index 0000000000000..226b89cfb2b86 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh @@ -0,0 +1,915 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +/* About Computing in these Pipelines + +B: batch size of Attention Op. NOTE: To be disambiguated with batch size of GEMMs +S: sequence length +T: total sequence length +N: num of heads +H: head dimension + +The following use qkv_format == Q_K_V_BNSH (mode == BSNH_BNLH_BNLH_NONE_NONE_NONE_NONE) as a example: + +BN: B*N, which is the batch size of GEMMs. NOTE: To be disambiguated with batch size of Attention Op + +In QKV projection (prior to this pipeline): + /-> Q [B,S,N*H] ->Reshape-> [B,S,N,H] ->Permute0213-> [B,N,S,H] +X --o--> K [B,T,N*H] ->Reshape-> [B,T,N,H] ->Permute0213-> [B,N,T,H] + \-> V [B,T,N*H] ->Reshape-> [B,T,N,H] ->Permute0213-> [B,N,T,H] + +pre_softmax_attn_scores = Q*K' = [B,N,S,H] * [BxNxTxH]' = [B,N,S,T] Batched GEMM1 +pre_softmax_attn_scores_masked = pre_softmax_attn_scores * scale +? bias +? mask Scale Add Bias, +? 
is optional +attn_scores = softmax(pre_softmax_attn_scores_masked) = [B,N,S,T] Softmax +scaled_multi_head_attn = attn_scores * V = [B,N,S,T] * [B,N,T,H] = [B,N,S,H] Batched GEMM2 + +Op outputs scaled_multi_head_attn: +[B,N,S,H] ->Permute0213-> [B,S,N,H] ->Reshape-> [B,S,N*H] + + +For the computing of pre_softmax_attn_scores +? mask +? bias: + +GemmSoftmaxGemmPermuteGenericPipeline handles it in specialized softmax. TODO: remove it! + +CK in GemmSoftmaxGemmPermuteTunablePipeline + + Q*K' ---> scale ---> [B,N,S,T] -------+?--> masked + bias --------------> [B,N,S,T] --+?--/ +mask_2d ---> [B,T] ---> [B,1,1,T] -/ + + Q*K' ---> scale ---> [B,N,S,T] -------+?--> masked + bias --------------> [B,N,S,T] --+?--/ +mask_3d --> [B,S,T] --> [B,1,S,T] -/ + + Q*K' ---> scale ---> [B,N,S,T] -------+?--> masked + bias --------------> [B,N,S,T] --+?--/ +mask_4d -> [B,1,M,M] -> [B,1,S,T] -/ M is max_sequence_length from megatron, we will create a + **sub-view** from original mask buffer + +For CK implementation, there will be four cases combined: +non-biased, non-masked, no special processing. + biased, non-masked, no special processing, add the mask directly. +non-biased, masked, convert the mask to [B,1,1_or_S,T] and perform broadcast add with scaled Q*K'. + biased, masked, convert the mask to [B,1,1_or_S,T] and perform broadcast add with bias and scaled Q*K'. + +Broadcast add is not actually perform the broadcasting, just broadcast the load operation from memory. The impl details +are in composable kernels. The scale and add logic is performed via Acc0ElementOp + +# Classified modes: + +| Q | K | V | past(K)| pastV | present(K)| presentV | Op, desc +| ---- | ---- | ---- | ------ | ----- | --------- | -------- | --------- +| QFMT | KFMT | VFMT | - | - | - | - | A, basic, qkv is impl dependent by qkv_format +| QFMT | KFMT | VFMT | 2BNPH | - | 2BNTH *^ | - | A, past_present_share_buffer = false, qkv is impl dependent by qkv_format +| QFMT | KFMT | VFMT | 2BNMH | - | 2BNMH *^ | - | A, past_present_share_buffer = true, qkv is impl dependent by qkv_format +| BSNH | BLNH*| BLNH^| - | - | - | - | MHA basic +| BSNH | BNLH*| BNLH^| - | - | - | - | MHA cross, pass_past_in_kv = true +| BSNH | - | - | - | - | BNLH * | BNLH ^ | MHA cross, pass_past_in_kv = false +| BSNH | BLNH | BLNH | - | - | BNTH * | BNTH ^ | MHA cross, past_present_share_buffer = false +| BSNH | BNLH | BNLH | - | - | BNTH * | BNTH ^ | MHA cross, past_present_share_buffer = false +| BSNH | BLNH | BLNH | - | - | BNMH * | BNMH ^ | MHA cross, past_present_share_buffer = true +| BSNH | BNLH | BNLH | - | - | BNMH * | BNMH ^ | MHA cross, past_present_share_buffer = true +| BSNH | BLNH | BLNH | BNPH | BNPH | BNTH * | BNTH ^ | MHA self, past_present_share_buffer = false +| BSNH | BNLH | BNLH | BNPH | BNPH | BNTH * | BNTH ^ | MHA self, past_present_share_buffer = false +| BSNH | BLNH | BLNH | BNMH | BNMH | BNMH * | BNMH ^ | MHA self, past_present_share_buffer = true +| BSNH | BNLH | BNLH | BNMH | BNMH | BNMH * | BNMH ^ | MHA self, past_present_share_buffer = true +| BLN3H*^| - | - | - | - | - | - | MHA basic, qkv_packed +| BSNH | BLN2H*^| - | - | - | - | - | MHA basic, kv_packed + +Q, K, V, past(K), pastV, present(K), presentV is the Input of the contrib OpKernel + +About k_buffer and v_buffer, we always explicitly concat past to present and use present_k for k_buffer and present_v for v_buffer + +- Marked with `*` indicate the Tensor is used for k_buffer passing. +- Marked with `^` indicate the Tensor is used for v_buffer passing. 
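
# Worked example: 2D key-padding mask (editor's illustration)

Take B=2, N=12, S=4, T=6. MASK_2D_KEY_PADDING supplies an int mask of shape [B,T] = [2,6].
GetRawMaskBufferAddrSizesAndStrides views it as sizes {2,1,6} with strides {6,0,1} (see
Get2DMaskStrides below): the zero stride broadcasts the single row over every query position.
On the CK path, ConvertToFilledMaskValue first rewrites it in the workspace as T-typed filled
values (0 where mask == 1, mask_filter_value elsewhere), still {2,1,6} and contiguous. That
buffer is then handed to the kernel as a [G0,G1,M,N] = [2,12,4,6] Gemm1 bias operand with
strides {6,0,0,1}, so the same 6 values are re-read for every head and every query row rather
than materializing a [2,12,4,6] mask tensor.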
+ +# Supported Op + +- A: Attention +- MHA: MultiHeadAttention + +# Dim Value + +- B: batch_size +- N: num_heads +- H: head_size + +- S: sequence_length +- L: kv_sequence_length +- P: past_sequence_length +- T: total_sequence_length = P + L +- M: max_sequence_length +*/ + +#include "core/framework/tensor_shape.h" +#include "core/providers/rocm/tunable/gemm.h" +#include "core/providers/rocm/tunable/rocm_tunable.h" +#include "contrib_ops/cpu/bert/attention_base.h" +#include "contrib_ops/rocm/bert/attention_impl.h" +#include "contrib_ops/rocm/bert/attention_softmax.h" +#ifdef USE_COMPOSABLE_KERNEL +#include "contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl.cuh" +#include "core/providers/rocm/composable_kernel_common.h" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#endif // USE_COMPOSABLE_KERNEL + +#include +#include + +namespace blas = onnxruntime::rocm::tunable::blas; + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +inline int3 Get2DMaskStrides(int total_sequence_length) { + // stride == 0 indicate broadcasting + return {total_sequence_length, 0, 1}; +} + +// A stride maps from natural coordinate to physical offset of underlying memory storage buffer offset. We need to +// specify both of the natural coordinate order, say (b,n,s,h), (b,s,n,h) or (b,n,h,s), and memory order, say BNSH or +// BSNH, to determain the strides. To obtain the offset, we just do the inner product of coordinate with the strides. +// This wrapper create the stride vector from the physical dimension (or physical shape). +struct Strides { + // Create the strides for BNSH physically indexed memory buffer + static Strides BNSHMemory(int batch_dim, + int num_head_dim, + int seqlen_dim, + int head_size_dim) { + ORT_UNUSED_PARAMETER(batch_dim); + return Strides{LongLong4{ + static_cast(num_head_dim) * seqlen_dim * head_size_dim, + static_cast(seqlen_dim) * head_size_dim, + static_cast(head_size_dim), + static_cast(1), + }}; + } + + // Create the strides for BSNH physically indexed memory buffer + static Strides BSNHMemory(int batch_dim, + int seqlen_dim, + int num_head_dim, + int head_size_dim) { + ORT_UNUSED_PARAMETER(batch_dim); + return Strides{LongLong4{ + static_cast(seqlen_dim) * num_head_dim * head_size_dim, + static_cast(head_size_dim), + static_cast(num_head_dim) * head_size_dim, + static_cast(1), + }}; + } + + template + T ForBNSHCoord() const { + using E = typename T::value_type; + return T{static_cast(strides_for_bnsh_coord.x), + static_cast(strides_for_bnsh_coord.y), + static_cast(strides_for_bnsh_coord.z), + static_cast(strides_for_bnsh_coord.w)}; + } + + template + T ForBSNHCoord() const { + using E = typename T::value_type; + return T{static_cast(strides_for_bnsh_coord.x), + static_cast(strides_for_bnsh_coord.z), + static_cast(strides_for_bnsh_coord.y), + static_cast(strides_for_bnsh_coord.w)}; + } + + template + T ForBNHSCoord() const { + using E = typename T::value_type; + return T{static_cast(strides_for_bnsh_coord.x), + static_cast(strides_for_bnsh_coord.y), + static_cast(strides_for_bnsh_coord.w), + static_cast(strides_for_bnsh_coord.z)}; + } + + int64_t OffsetAt(int b, int n, int s, int h) const { + return strides_for_bnsh_coord.x * b + strides_for_bnsh_coord.y * n + + strides_for_bnsh_coord.z * s + strides_for_bnsh_coord.w * h; + } + + // store intermediate strides in the canonical (b,n,s,h) coordinate order + LongLong4 strides_for_bnsh_coord; +}; + 
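The Strides helper above is easiest to see with concrete numbers. Below is a self-contained check (editor's illustration, plain integers only): for a BSNH-laid-out buffer with B=2, S=3, N=4, H=8, BSNHMemory yields strides {S*N*H, H, N*H, 1} = {96, 8, 32, 1} for the canonical (b,n,s,h) coordinate order, so OffsetAt(1, 2, 0, 5) lands on the same element as row-major [B,S,N,H] indexing of (b=1, s=0, n=2, h=5).

#include <cassert>
int main() {
  const int S = 3, N = 4, H = 8;  // B = 2 does not enter the strides themselves
  // Strides::BSNHMemory(B, S, N, H) for the (b,n,s,h) coordinate order.
  const long long sb = 1LL * S * N * H, sn = H, ss = 1LL * N * H, sh = 1;
  // OffsetAt(b=1, n=2, s=0, h=5)
  const long long offset = sb * 1 + sn * 2 + ss * 0 + sh * 5;
  // The same element addressed in a row-major [B,S,N,H] buffer: (b=1, s=0, n=2, h=5).
  const long long expected = ((1LL * S + 0) * N + 2) * H + 5;
  assert(offset == expected && offset == 117);
  return 0;
}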
+template +std::tuple ConvertToOffsetedBufferViews( + const RocmAttentionParameters* attn, + const T* query = nullptr, // q or packed_qkv + const T* key = nullptr, // k or packed kv + const T* value = nullptr, // + const T* present = nullptr, // present or present_k + const T* present_v = nullptr) { + switch (attn->mode) { + case QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE: + case BSNH_BNLH_BNLH_NONE_NONE_NONE_NONE: + case BSNH_BLNH_BLNH_NONE_NONE_NONE_NONE: { + return {reinterpret_cast(query), + reinterpret_cast(key), + reinterpret_cast(value)}; + } + case QFMT_KFMT_VFMT_NONE_NONE_2BNTH_NONE: + case QFMT_KFMT_VFMT_2BNPH_NONE_2BNTH_NONE: { + auto offset = static_cast(attn->batch_size) * attn->num_heads * attn->total_sequence_length * + attn->head_size; + return {reinterpret_cast(query), + reinterpret_cast(present), + reinterpret_cast(present) + offset}; + } + case QFMT_KFMT_VFMT_NONE_NONE_2BNMH_NONE: + case QFMT_KFMT_VFMT_2BNMH_NONE_2BNMH_NONE: { + auto offset = static_cast(attn->batch_size) * attn->num_heads * attn->max_sequence_length * + attn->head_size; + return {reinterpret_cast(query), + reinterpret_cast(present), + reinterpret_cast(present) + offset}; + } + case BSNH_BLNH_BLNH_NONE_NONE_BNTH_BNTH: + case BSNH_BNLH_BNLH_NONE_NONE_BNTH_BNTH: + case BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH: + case BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH: + case BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH: + case BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH: + case BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH: + case BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH: + return {reinterpret_cast(query), + reinterpret_cast(present), + reinterpret_cast(present_v)}; + case BSNH_BLN2H_NONE_NONE_NONE_NONE_NONE: { + auto packed_kv = reinterpret_cast(key); + return {reinterpret_cast(query), packed_kv, packed_kv + attn->head_size}; + } + case BLN3H_NONE_NONE_NONE_NONE_NONE_NONE: { + auto packed_qkv = reinterpret_cast(query); + return {packed_qkv, packed_qkv + 1 * attn->head_size, packed_qkv + 2 * attn->head_size}; + } + default: + ORT_ENFORCE("unreachable"); + return {}; + } +} + +inline std::tuple GetQkvStrides(const RocmAttentionParameters* attn) { + // G0 not used, because it is the slowest dimension + const int& B = attn->batch_size; + const int& N = attn->num_heads; + const int& S = attn->sequence_length; + const int& L = attn->kv_sequence_length; + const int& T = attn->total_sequence_length; + const int& M = attn->max_sequence_length; + const int& H = attn->head_size; + const int& Hv = attn->v_head_size; + + switch (attn->mode) { + case QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE: + if (attn->qkv_format == Q_K_V_BNSH) { + return { + Strides::BNSHMemory(B, N, S, H), + Strides::BNSHMemory(B, N, L, H), + Strides::BNSHMemory(B, N, L, Hv), + }; + } else if (attn->qkv_format == Q_K_V_BSNH) { + return { + Strides::BSNHMemory(B, S, N, H), + Strides::BSNHMemory(B, L, N, H), + Strides::BSNHMemory(B, L, N, Hv), + }; + } + case BSNH_BLNH_BLNH_NONE_NONE_BNTH_BNTH: + case BSNH_BNLH_BNLH_NONE_NONE_BNTH_BNTH: + case BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH: + case BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH: + return { + Strides::BSNHMemory(B, S, N, H), + Strides::BNSHMemory(B, N, T, H), + Strides::BNSHMemory(B, N, T, Hv), + }; + case BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH: + case BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH: + case BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH: + case BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH: + return { + Strides::BSNHMemory(B, S, N, H), + Strides::BNSHMemory(B, N, M, H), + Strides::BNSHMemory(B, N, M, Hv), + }; + case BSNH_BLNH_BLNH_NONE_NONE_NONE_NONE: + return { + Strides::BSNHMemory(B, S, N, H), + 
Strides::BSNHMemory(B, L, N, H), + Strides::BSNHMemory(B, L, N, Hv), + }; + case BSNH_BNLH_BNLH_NONE_NONE_NONE_NONE: + return { + Strides::BSNHMemory(B, S, N, H), + Strides::BNSHMemory(B, N, L, H), + Strides::BNSHMemory(B, N, L, Hv), + }; + case BSNH_BLN2H_NONE_NONE_NONE_NONE_NONE: + return { + Strides::BSNHMemory(B, S, N, H), + Strides::BSNHMemory(B, L, N, 2 * H), + Strides::BSNHMemory(B, L, N, 2 * Hv), + }; + case BLN3H_NONE_NONE_NONE_NONE_NONE_NONE: + return { + Strides::BSNHMemory(B, L, N, 3 * H), + Strides::BSNHMemory(B, L, N, 3 * H), + Strides::BSNHMemory(B, L, N, 3 * Hv), + }; + default: + ORT_ENFORCE("unreachable"); + return {}; + } +} + +inline std::tuple GetRawMaskBufferAddrSizesAndStrides( + const int* buffer, const RocmAttentionParameters* attn) { + const int* offseted_buffer{buffer}; // how to view the mask buffer + int3 sizes{0, 0, 0}; // the logical shape of the view + int3 strides{-1, -1, -1}; // the physical memory layout + switch (attn->mask_type) { + case MASK_NONE: + case MASK_2D_DUMMY: + break; // No mask + case MASK_2D_KEY_PADDING: + sizes = {attn->batch_size, 1, attn->total_sequence_length}; + strides = Get2DMaskStrides(attn->total_sequence_length); + break; + case MASK_3D_ATTENTION: + sizes = {attn->batch_size, attn->sequence_length, attn->total_sequence_length}; + strides = {attn->sequence_length * attn->total_sequence_length, attn->total_sequence_length, 1}; + break; + case MASK_4D_MEGATRON: + // offset to skip past sequence part, so that we can index it with [batch_index, sequence_index, token_index] + offseted_buffer = buffer + attn->past_sequence_length * attn->max_sequence_length; + sizes = {attn->batch_size, attn->sequence_length, attn->total_sequence_length}; + strides = {attn->max_sequence_length * attn->max_sequence_length, attn->max_sequence_length, 1}; + break; + default: + LOGS_DEFAULT(FATAL) << "unsupported mask type: " << attn->mask_type; + throw std::runtime_error("unsupported mask type"); + } + return {offseted_buffer, sizes, strides}; +} + +template +struct GemmSoftmaxGemmPermuteParams : onnxruntime::rocm::tunable::OpParams { + std::string Signature() const override { + return MakeString( + "B", attention->batch_size, + "_S", attention->sequence_length, + "_T", attention->total_sequence_length, + "_N", attention->num_heads, + "_H", attention->head_size, + "_Hv", attention->v_head_size, + bias_buffer != nullptr ? "_B" : "_NB", + "_M", mask_index_dims.size(), + "_QKV", attention->qkv_format, + "_MODE", attention->mode); + } + + std::tuple GetGemmsMNKOBatch() const { + ORT_ENFORCE(attention != nullptr); + auto m = attention->sequence_length; + auto n = attention->total_sequence_length; + auto k = attention->head_size; + auto o = attention->v_head_size; + auto batch = attention->batch_size * attention->num_heads; + return {m, n, k, o, batch}; + } + + hipblasHandle_t handle; + const RocmAttentionParameters* attention; + const hipDeviceProp_t* device_prop; + + float scale; + const T* q_buffer; + const T* k_buffer; + const T* v_buffer; + T* out_buffer; + + // optional, attention bias [B,N,S,T] + // TODO: support shape [B,1,S,T], [1, N, S, T], [1, 1, S, T] with broadcast. 
+ const T* bias_buffer{nullptr}; + + // optional, mask value + const int* mask_index_buffer{nullptr}; + TensorShapeVector mask_index_dims{}; + + // optional, internal + void* workspace_buffer{nullptr}; +}; + +inline bool IsKVBNMH(AttentionMode mode) { + switch (mode) { + case QFMT_KFMT_VFMT_NONE_NONE_2BNMH_NONE: + case QFMT_KFMT_VFMT_2BNMH_NONE_2BNMH_NONE: + case BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH: + case BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH: + case BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH: + case BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH: + return true; + default: + return false; + } +} + +template +struct GemmSoftmaxGemmPermuteGenericPipeline { + static bool UseRawAttentionMask(const GemmSoftmaxGemmPermuteParams* params) { + return params->mask_index_buffer != nullptr && params->mask_index_dims.size() >= 2; + } + + static std::tuple GetWorkspacePlan(const GemmSoftmaxGemmPermuteParams* params) { + auto bytes = GetAttentionScratchSize( + sizeof(T), + params->attention->batch_size, + params->attention->num_heads, + params->attention->sequence_length, + params->attention->total_sequence_length); + auto gemm1_out = reinterpret_cast(params->workspace_buffer); + auto softmax_out = gemm1_out + (bytes / sizeof(T)); + auto gemm2_out = softmax_out + (bytes / sizeof(T)); + return {gemm1_out, softmax_out, gemm2_out}; + } + + inline static size_t GetWorkspaceNumBytes(const RocmAttentionParameters* attn) { + return GetAttentionWorkspaceSize( + sizeof(T), + attn->batch_size, + attn->num_heads, + attn->head_size, + attn->sequence_length, + attn->total_sequence_length); + } + + inline static Status Gemm1(const GemmSoftmaxGemmPermuteParams* params) { + auto [m, n, k, o, batch] = params->GetGemmsMNKOBatch(); + auto [gemm1_out, softmax_out, gemm2_out] = GetWorkspacePlan(params); + + int k_buffer_stride = n * k; + if (IsKVBNMH(params->attention->mode)) { + k_buffer_stride = params->attention->max_sequence_length * params->attention->head_size; + } + + // GEMM1 [m,k] * [n,k]' -> [m,n] + return blas::row_major::StridedBatchedGemm( + params->TuningContext(), params->Stream(), params->handle, + blas::BlasOp::NonTrans, blas::BlasOp::Trans, + m, n, k, + // For raw attention mask, the scalar is moved to softmax computation. + /*alpha=*/UseRawAttentionMask(params) ? 1.0f : params->scale, + params->q_buffer, k, m * k, + params->k_buffer, k, k_buffer_stride, + /*beta=*/0.0f, + gemm1_out, n, m * n, + batch); + } + + inline static Status SoftmaxRawMask(const GemmSoftmaxGemmPermuteParams* params, bool use_persistent_softmax) { + // Softmax on [m,n] along the n dimension. + // Raw attention mask could be 2D (B,S) or 3D (B,S,T) or 4D(B,1,M,M), where M is the max sequence length. + auto attn = params->attention; + auto [buffer, sizes, strides] = GetRawMaskBufferAddrSizesAndStrides(params->mask_index_buffer, attn); + auto [gemm1_out, softmax_out, gemm2_out] = GetWorkspacePlan(params); + T* persistent_softmax_workspace = gemm1_out; // replace Q*K' in place if persistent softmax is selected. + return ComputeSoftmaxWithRawMask( + params->Stream(), attn->total_sequence_length, attn->sequence_length, attn->batch_size, attn->num_heads, + strides, buffer, nullptr, params->bias_buffer, gemm1_out, softmax_out, + attn->is_unidirectional, /* FIXME: this must not be attn.scale! 
*/ params->scale, + use_persistent_softmax, persistent_softmax_workspace, attn->mask_filter_value); + } + + inline static Status Softmax1DIndexMask(const GemmSoftmaxGemmPermuteParams* params) { + auto mask_1d = params->mask_index_buffer; + auto mask_1d_size = params->mask_index_dims[0]; + // Softmax on [m,n] along the n dimension. + // mask_index has 1D shape: either (batch_size) or (2*batch_size). Only the later one has start postions. + auto attn = params->attention; + auto [gemm1_out, softmax_out, gemm2_out] = GetWorkspacePlan(params); + const int* mask_start = (mask_1d_size > attn->batch_size) ? mask_1d + attn->batch_size : nullptr; + return ComputeSoftmaxWithMask1D( + params->StreamHandle(), attn->total_sequence_length, attn->sequence_length, attn->batch_size, attn->num_heads, + mask_1d, mask_start, params->bias_buffer, gemm1_out, softmax_out, attn->is_unidirectional); + } + + inline static Status SoftmaxNoMask(const GemmSoftmaxGemmPermuteParams* params) { + // Softmax on [m,n] along the n dimension. + auto attn = params->attention; + auto [gemm1_out, softmax_out, gemm2_out] = GetWorkspacePlan(params); + return ComputeSoftmax( + params->StreamHandle(), attn->total_sequence_length, attn->sequence_length, attn->batch_size, attn->num_heads, + params->bias_buffer, gemm1_out, softmax_out, attn->is_unidirectional); + } + + inline static Status Gemm2(const GemmSoftmaxGemmPermuteParams* params) { + auto [m, n, k, o, batch] = params->GetGemmsMNKOBatch(); + auto [gemm1_out, softmax_out, gemm2_out] = GetWorkspacePlan(params); + + int v_buffer_stride = n * o; + if (IsKVBNMH(params->attention->mode)) { + v_buffer_stride = params->attention->max_sequence_length * params->attention->v_head_size; + } + + // GEMM2 [m,n] * [n,o] -> [m,o] + // semantically, the output buffer contains B*N matrices of shape [S,H], compactly, thus B,N,S,H. 
+ return blas::row_major::StridedBatchedGemm( + params->TuningContext(), params->Stream(), params->handle, + blas::BlasOp::NonTrans, blas::BlasOp::NonTrans, + m, o, n, + /*alpha=*/1.0f, + softmax_out, n, m * n, + params->v_buffer, o, v_buffer_stride, + /*beta=*/0.0f, + gemm2_out, o, m * o, + batch); + } + + inline static Status Permute0213(const GemmSoftmaxGemmPermuteParams* params) { + // Permute 0213 + // gemm2_out is B,N,S,H, transpose to out_buffer as B,S,N,H + auto attn = params->attention; + auto [gemm1_out, softmax_out, gemm2_out] = GetWorkspacePlan(params); + return LaunchTransCtx( + params->StreamHandle(), + attn->sequence_length, attn->batch_size, attn->head_size, attn->num_heads, + params->device_prop->maxThreadsPerBlock, false, gemm2_out, params->out_buffer); + } + + static Status GetSupportedStatus(const GemmSoftmaxGemmPermuteParams* params) { + const auto& attn = params->attention; + // TODO: address the BNMH k,v strides + switch (attn->mode) { + case QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE: + case QFMT_KFMT_VFMT_NONE_NONE_2BNTH_NONE: + case QFMT_KFMT_VFMT_2BNPH_NONE_2BNTH_NONE: + case QFMT_KFMT_VFMT_NONE_NONE_2BNMH_NONE: + case QFMT_KFMT_VFMT_2BNMH_NONE_2BNMH_NONE: + if (attn->qkv_format == Q_K_V_BNSH) { + return Status::OK(); + } else { + return TUNABLE_OP_UNSUPPORTED("GenericPipeline only supports qkv_format as Q_K_V_BNSH, got ", + attn->qkv_format); + } + case BSNH_BLNH_BLNH_NONE_NONE_NONE_NONE: + return TUNABLE_OP_UNSUPPORTED("GenericPipeline only supports qkv_format as Q_K_V_BNSH but k, v are BLNH"); + case BSNH_BNLH_BNLH_NONE_NONE_NONE_NONE: + case BSNH_BLNH_BLNH_NONE_NONE_BNTH_BNTH: + case BSNH_BNLH_BNLH_NONE_NONE_BNTH_BNTH: + case BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH: + case BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH: + case BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH: + case BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH: + case BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH: + case BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH: + // If sequence_length is 1, query of B1NH can be simply viewed as BN1H. 
+ if (attn->sequence_length == 1) { + return Status::OK(); + } else { + return TUNABLE_OP_UNSUPPORTED("GenericPipeline only supports qkv_format as Q_K_V_BNSH, ", + "only if sequence_length is 1, query of BSNH can be viewed as BNSH"); + } + case BLN3H_NONE_NONE_NONE_NONE_NONE_NONE: + case BSNH_BLN2H_NONE_NONE_NONE_NONE_NONE: + return TUNABLE_OP_UNSUPPORTED("GenericPipeline only supports qkv_format as Q_K_V_BNSH"); + default: + return TUNABLE_OP_UNSUPPORTED("unknonw"); + } + return TUNABLE_OP_UNSUPPORTED("unknonw case"); + } + + static Status Run(const GemmSoftmaxGemmPermuteParams* params, bool use_persistent_softmax) { + auto supported_status = GetSupportedStatus(params); + if (!supported_status.IsOK()) { + return supported_status; + } + ORT_RETURN_IF_ERROR(Gemm1(params)); + + if (UseRawAttentionMask(params)) { + ORT_RETURN_IF_ERROR(SoftmaxRawMask(params, use_persistent_softmax)); + } else if (params->mask_index_dims.size() == 1) { // 1d index mask + ORT_RETURN_IF_ERROR(Softmax1DIndexMask(params)); + } else { + ORT_RETURN_IF_ERROR(SoftmaxNoMask(params)); + } + + ORT_RETURN_IF_ERROR(Gemm2(params)); + ORT_RETURN_IF_ERROR(Permute0213(params)); + return Status::OK(); + } +}; + +template +class GemmSoftmaxGemmPermuteTunableOp : public tunable::TunableOp> { + public: + GemmSoftmaxGemmPermuteTunableOp(); + + inline static bool IsSupportedMode(const RocmAttentionParameters* attn) { + switch (attn->mode) { + case QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE: + case QFMT_KFMT_VFMT_2BNPH_NONE_2BNTH_NONE: + // depends on qkv format + if (attn->qkv_format == Q_K_V_BNSH || attn->qkv_format == Q_K_V_BSNH) { + return true; + } else { + return false; + } + case BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH: + case BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH: + case BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH: + case BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH: + case BSNH_BLNH_BLNH_NONE_NONE_NONE_NONE: + case BSNH_BNLH_BNLH_NONE_NONE_NONE_NONE: + case BSNH_BLNH_BLNH_NONE_NONE_BNTH_BNTH: + case BSNH_BNLH_BNLH_NONE_NONE_BNTH_BNTH: + case BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH: + case BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH: + case BSNH_BLN2H_NONE_NONE_NONE_NONE_NONE: + case BLN3H_NONE_NONE_NONE_NONE_NONE_NONE: + return true; + default: + return false; + } + } + + inline static bool IsSupportedMaskType(const RocmAttentionParameters* attn) { + switch (attn->mask_type) { + case MASK_NONE: + case MASK_2D_DUMMY: + case MASK_2D_KEY_PADDING: + case MASK_3D_ATTENTION: + case MASK_4D_MEGATRON: + return true; + default: + return false; + } + } + + inline static size_t GetWorkspaceNumBytes(const RocmAttentionParameters* attn) { + size_t num_bytes = GemmSoftmaxGemmPermuteGenericPipeline::GetWorkspaceNumBytes(attn); + +#ifdef USE_COMPOSABLE_KERNEL + if (IsSupportedMaskType(attn)) { + auto [buffer, sizes, strides] = GetRawMaskBufferAddrSizesAndStrides(nullptr, attn); + num_bytes = std::max(num_bytes, sizeof(T) * sizes.x * sizes.y * sizes.z); + } +#endif + + return num_bytes; + } + + template + __global__ static void ConvertToFilledMaskValue( + T* __restrict__ out, + const int3 out_strides, + const int* __restrict__ mask_buffer, + const int3 mask_lengths, // [B,S,T] + const int3 mask_strides, + Converter cvt) { + const int64_t global_idx = blockDim.x * blockIdx.x + threadIdx.x; + if (global_idx >= mask_lengths.x * mask_lengths.y * CeilDiv(mask_lengths.z, VecSize)) { + return; + } + + const int tidx = (global_idx % CeilDiv(mask_lengths.z, VecSize)) * VecSize; + const int bs_idx = global_idx / CeilDiv(mask_lengths.z, VecSize); + const int sidx = bs_idx % mask_lengths.y; + const int 
bidx = bs_idx / mask_lengths.y; + + int64_t in_offset = mask_strides.x * bidx + mask_strides.y * sidx + mask_strides.z * tidx; + int64_t out_offset = out_strides.x * bidx + out_strides.y * sidx + out_strides.z * tidx; + + if (tidx + VecSize <= mask_lengths.z) { + using LoadT = const aligned_vector; + using StoreT = aligned_vector; + LoadT load = *reinterpret_cast(mask_buffer + in_offset); + StoreT store; + +#pragma unroll + for (int i = 0; i < VecSize; i++) { + store.val[i] = cvt(load.val[i]); + } + *reinterpret_cast(out + out_offset) = store; + } else { +#pragma unroll + for (int i = 0; i < mask_lengths.z - tidx; i++) { + out[out_offset + i] = cvt(mask_buffer[in_offset + i]); + } + } + } + + static Status LaunchConvertToFilledMaskValue(const GemmSoftmaxGemmPermuteParams* params) { + constexpr const int kThreadPerBlock = 256; + constexpr const int kVecSize = 4; + + auto attn = params->attention; + auto [buffer, lengths, strides] = GetRawMaskBufferAddrSizesAndStrides(params->mask_index_buffer, attn); + int64_t total_threads = lengths.x * lengths.y * CeilDiv(lengths.z, kVecSize); + auto num_blocks = CeilDiv(total_threads, kThreadPerBlock); + + auto mask_filter_value = attn->mask_filter_value; + auto cvt = [=] __device__(int v) -> T { + return v == 1 ? 0 : mask_filter_value; + }; + + ConvertToFilledMaskValue<<StreamHandle()>>>( + reinterpret_cast(params->workspace_buffer), {lengths.y * lengths.z, lengths.z, 1}, // out desc + buffer, lengths, strides, // mask desc + cvt); + + return HIP_CALL(hipGetLastError()); + } +}; + +#ifdef USE_COMPOSABLE_KERNEL + +template +auto GetArgAndRunInvoker(const U& impl, const V& invoker, const GemmSoftmaxGemmPermuteParams* params) { + constexpr const int kNumBiasBuffer = static_cast(USE_BIAS) + static_cast(USE_MASK); + + using Nop = ck::tensor_operation::element_wise::PassThrough; + using Acc0ElementOp = internal::PreSoftmaxAttentionScoreOp; + + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + !GemmSoftmaxGemmPermuteTunableOp::IsSupportedMode(params->attention), + "attention mode is not supported, got ", params->attention->mode); + if constexpr (USE_BIAS) { + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + params->bias_buffer == nullptr, "biased version only support input with bias"); + } else { + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + params->bias_buffer != nullptr, "non-biased version only support input without bias"); + } + if constexpr (USE_MASK) { + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + !GemmSoftmaxGemmPermuteTunableOp::IsSupportedMaskType(params->attention), + "mask type is not supported, got ", params->attention->mask_type); + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + params->mask_index_buffer == nullptr, "masked version only support input with mask"); + } else { + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + params->mask_index_buffer != nullptr, "non-masked version only support input without mask"); + } + + auto attn = params->attention; + const int& G0 = attn->batch_size; + const int& G1 = attn->num_heads; + const int& M = attn->sequence_length; + const int& N = attn->total_sequence_length; + const int& K = attn->head_size; + const int& O = attn->v_head_size; + { + auto [m, n, k, o, batch] = params->GetGemmsMNKOBatch(); + ORT_ENFORCE(M == m && N == n && K == k && O == o && G0 * G1 == batch, "semantic mismatch"); + } + + auto [qs, ks, vs] = GetQkvStrides(attn); + std::vector q_buffer_lengths = {G0, G1, M, K}; + std::vector q_buffer_strides = qs.template ForBNSHCoord>(); + std::vector k_buffer_lengths = {G0, G1, N, K}; + std::vector 
k_buffer_strides = ks.template ForBNSHCoord>(); + std::vector v_buffer_lengths = {G0, G1, O, N}; + std::vector v_buffer_strides = vs.template ForBNHSCoord>(); + std::vector out_buffer_lengths = {G0, G1, M, O}; + std::vector out_buffer_strides = {M * G1 * O, O, G1 * O, 1}; // permute 0213 + + std::array bias_buffers{}; + std::array, kNumBiasBuffer> bias_lengths{}; + std::array, kNumBiasBuffer> bias_strides{}; + if constexpr (USE_BIAS) { + bias_buffers[0] = const_cast(params->bias_buffer); + bias_lengths[0] = {G0, G1, M, N}; // BN(G0*G1), S(M), T(N) + bias_strides[0] = {G1 * M * N, M * N, N, 1}; + } + if constexpr (USE_MASK) { + bias_buffers[kNumBiasBuffer - 1] = params->workspace_buffer; + bias_lengths[kNumBiasBuffer - 1] = {G0, G1, M, N}; // BN(G0*G1), S(M), T(N) + if (params->mask_index_dims.size() == 2) { // [B,T] + bias_strides[kNumBiasBuffer - 1] = {N, 0, 0, 1}; + } else if (params->mask_index_dims.size() == 3) { // [B,S,T] + bias_strides[kNumBiasBuffer - 1] = {M * N, 0, N, 1}; + } else if (params->mask_index_dims.size() == 4) { // [B,1,max_seq_len,max_seq_len] -->convert--> [B,S,T] + bias_strides[kNumBiasBuffer - 1] = {M * N, 0, N, 1}; + } else { + ORT_ENFORCE(false, "Unreachable"); + } + } + + auto arg = impl->MakeArgumentPointer( + params->q_buffer, params->k_buffer, params->v_buffer, params->out_buffer, + bias_buffers, // Gemm1 bias, as attention mask + {}, // Gemm2 bias + q_buffer_lengths, q_buffer_strides, + k_buffer_lengths, k_buffer_strides, + v_buffer_lengths, v_buffer_strides, + out_buffer_lengths, out_buffer_strides, + bias_lengths, bias_strides, + {}, + {}, + Nop{}, + Nop{}, + Acc0ElementOp{params->scale}, + Nop{}, + Nop{}); + + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), + impl->GetTypeString(), " does not support the params"); + + if constexpr (USE_MASK) { + ORT_RETURN_IF_ERROR(GemmSoftmaxGemmPermuteTunableOp::LaunchConvertToFilledMaskValue(params)); + } + + invoker->Run(arg.get(), StreamConfig{params->StreamHandle()}); + return Status::OK(); +} + +template +auto GetCKGemmSoftmaxGemmPermuteTypeStringAndOps() { + using CKDataType = typename CKDataTypeAdaptor::type; + using D0DataType = typename ck::detail::tuple_concat< + std::conditional_t, ck::Tuple<>>, + std::conditional_t, ck::Tuple<>>>::type; + + constexpr static auto MaskingSpecMaskDisabled = + ck::tensor_operation::device::MaskingSpecialization::MaskDisabled; + constexpr static auto MaskingSpecMaskOutUpperTriangle = + ck::tensor_operation::device::MaskingSpecialization::MaskOutUpperTriangle; + + std::vector>>> + ret; + + for (auto&& impl : internal::GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< + CKDataType, D0DataType, internal::F32, internal::PreSoftmaxAttentionScoreOp, MaskingSpecMaskDisabled>()) { + auto type_string = impl->GetTypeString(); + + auto invoker = impl->MakeInvokerPointer(); + auto op = [impl = std::move(impl), invoker = std::move(invoker)]( + const GemmSoftmaxGemmPermuteParams* params) -> Status { + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + params->attention->is_unidirectional, "unidirectional attention is not supported with MaskingSpecMaskDisabled"); + + return GetArgAndRunInvoker(impl, invoker, params); + }; + ret.emplace_back(std::make_pair(std::move(type_string), std::move(op))); + } + + for (auto&& impl : internal::GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< + CKDataType, D0DataType, internal::F32, internal::PreSoftmaxAttentionScoreOp, MaskingSpecMaskOutUpperTriangle>()) { + auto type_string = impl->GetTypeString(); + + auto invoker = 
impl->MakeInvokerPointer(); + auto op = [impl = std::move(impl), invoker = std::move(invoker)]( + const GemmSoftmaxGemmPermuteParams* params) -> Status { + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + !params->attention->is_unidirectional, "bidirectional attention is not supported with MaskingSpecMaskOutUpperTriangle"); + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + params->attention->sequence_length != params->attention->total_sequence_length, + "seqence_length != total_seqence_length is not supported with MaskingSpecMaskOutUpperTriangle"); + + return GetArgAndRunInvoker(impl, invoker, params); + }; + ret.emplace_back(std::make_pair(std::move(type_string), std::move(op))); + } + + return ret; +} +#endif // USE_COMPOSABLE_KERNEL + +template +GemmSoftmaxGemmPermuteTunableOp::GemmSoftmaxGemmPermuteTunableOp() { + this->RegisterOp([](const GemmSoftmaxGemmPermuteParams* params) { + return GemmSoftmaxGemmPermuteGenericPipeline::Run(params, false); + }); + +#ifdef USE_COMPOSABLE_KERNEL + for (auto&& [_, op] : GetCKGemmSoftmaxGemmPermuteTypeStringAndOps()) { + this->RegisterOp(std::move(op)); + } + + for (auto&& [_, op] : GetCKGemmSoftmaxGemmPermuteTypeStringAndOps()) { + this->RegisterOp(std::move(op)); + } + + for (auto&& [_, op] : GetCKGemmSoftmaxGemmPermuteTypeStringAndOps()) { + this->RegisterOp(std::move(op)); + } + + for (auto&& [_, op] : GetCKGemmSoftmaxGemmPermuteTypeStringAndOps()) { + this->RegisterOp(std::move(op)); + } +#endif +} + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/decoder_attention_impl.h b/onnxruntime/contrib_ops/rocm/bert/decoder_attention_impl.h new file mode 100644 index 0000000000000..0aff519d20e99 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/decoder_attention_impl.h @@ -0,0 +1,46 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
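Editor's sketch of how the pipelines header that ends above (batched_gemm_softmax_gemm_permute_pipelines.cuh) is expected to be driven by an attention kernel. The wrapper name and its parameters are hypothetical; the member assignments and the static-TunableOp dispatch mirror the launchers elsewhere in this diff, and in the real kernels the q/k/v pointers come from ConvertToOffsetedBufferViews and the workspace is sized by GetWorkspaceNumBytes.

template <typename T>
Status RunGemmSoftmaxGemmPermute(RocmTuningContext* tuning_ctx, Stream* stream,
                                 hipblasHandle_t handle, const hipDeviceProp_t* device_prop,
                                 const RocmAttentionParameters* attn,
                                 const T* q, const T* k, const T* v,
                                 const T* attn_bias,                 // may be nullptr
                                 const int* mask_index,              // may be nullptr
                                 TensorShapeVector mask_index_dims,  // empty when no mask
                                 void* workspace,                    // >= GetWorkspaceNumBytes(attn)
                                 T* out, float scale) {              // e.g. 1/sqrt(head_size)
  GemmSoftmaxGemmPermuteParams<T> params;
  params.tuning_ctx = tuning_ctx;
  params.stream = stream;
  params.handle = handle;
  params.attention = attn;
  params.device_prop = device_prop;
  params.scale = scale;
  params.q_buffer = q;
  params.k_buffer = k;
  params.v_buffer = v;
  params.out_buffer = out;
  params.bias_buffer = attn_bias;
  params.mask_index_buffer = mask_index;
  params.mask_index_dims = mask_index_dims;
  params.workspace_buffer = workspace;

  if (tuning_ctx->IsTunableOpEnabled()) {
    static GemmSoftmaxGemmPermuteTunableOp<T> op;
    return op(&params);
  }
  return GemmSoftmaxGemmPermuteGenericPipeline<T>::Run(&params, /*use_persistent_softmax=*/false);
}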
+ +#pragma once + +#include +#include +#include "contrib_ops/cpu/bert/attention_common.h" +#include "core/providers/rocm/shared_inc/rocm_utils.h" +#include "core/providers/rocm/tunable/rocm_tunable.h" + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +Status LaunchDecoderAttentionKernel( + const hipDeviceProp_t& prop, // Device Properties + RocmTuningContext* tuning_ctx, // context for tuning + Stream* stream, // ORT Stream + hipblasHandle_t& hipblas, // hipblas handle + const size_t element_size, // Element size of input tensor + const int batch_size, // Batch size (B) + const int sequence_length, // Sequence length (S) + const int kv_sequence_length, // Key/Value/Cache sequence length + const int num_heads, // Number of attention heads (N) + const int head_size, // Hidden layer size per head (H) + const bool static_kv, // Whether cross attention or not + const bool use_past, // Whether use cache or not + const bool has_layer_state, // Whether output cache or not + const bool has_key_padding_mask, // Whether use key_padding_mask or not + const float mask_filter_value, // Mask filter value + const void* gemm_query_buffer, // Query buffer + const void* gemm_kv_buffer, // Key and value buffer + const bool* key_padding_mask, // Key padding mask + const void* key_cache, // Input key cache + const void* value_cache, // Input value cache + void* qkv_buffer, // Temporary buffer + void* workspace_buffer, // Temporary buffer + void* output, // Output tensor + void* new_key_cache, // New_key_cache tensor + void* new_value_cache // New_value_cache tensor +); + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/elementwise.h b/onnxruntime/contrib_ops/rocm/bert/elementwise.h new file mode 100644 index 0000000000000..768295767835a --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/elementwise.h @@ -0,0 +1,84 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/common.h" +#include "core/providers/rocm/tunable/rocm_tunable.h" + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +template +Status LaunchElementwiseKernel(RocmTuningContext* tuning_ctx, Stream* stream, + const T* input, int input_length, + const T* bias, int bias_length, + T* output); + +// The following is LaunchElementwiseKernel implementation detail. Their interfaces are exposed for kernel explorer. 
+namespace internal { + +template +struct ElementwiseParams : OpParams { + ElementwiseParams(RocmTuningContext* tuning_ctx, onnxruntime::Stream* stream, + const T* input, const T* bias, T* output, int input_length, int bias_length) + : OpParams(tuning_ctx, stream), + input(input), + bias(bias), + output(output), + input_length(input_length), + bias_length(bias_length) {} + + std::string Signature() const override { + std::string sig = std::to_string(input_length) + "_" + std::to_string(bias_length); + return sig; + } + + const T* input; + const T* bias; + T* output; + int input_length; + int bias_length; +}; + +template +class ElementwiseOp { + public: + Status operator()(const ElementwiseParams* params); + Status IsSupported(const ElementwiseParams* params); +}; + +template +Status ElementwiseStaticSelection(const ElementwiseParams* params); + +template +class ElementwiseTunableOp : public TunableOp> { + public: + ElementwiseTunableOp(); +}; + +} // namespace internal + +#define ELEMENTWISE_FWD_DECL(FnName, T) \ + namespace functor { \ + struct FnName; \ + } + +ELEMENTWISE_FWD_DECL(FastGeLU, float); +ELEMENTWISE_FWD_DECL(FastGeLU, double); +ELEMENTWISE_FWD_DECL(FastGeLU, half); +ELEMENTWISE_FWD_DECL(FastGeLU, BFloat16); + +ELEMENTWISE_FWD_DECL(GeLU, float); +ELEMENTWISE_FWD_DECL(GeLU, double); +ELEMENTWISE_FWD_DECL(GeLU, half); +ELEMENTWISE_FWD_DECL(GeLU, BFloat16); + +ELEMENTWISE_FWD_DECL(ReLU, float); +ELEMENTWISE_FWD_DECL(ReLU, half); +ELEMENTWISE_FWD_DECL(ReLU, BFloat16); + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl.cuh b/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl.cuh new file mode 100644 index 0000000000000..8255e70d27e48 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl.cuh @@ -0,0 +1,256 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/rocm/tunable/util.h" +#include "core/providers/rocm/cu_inc/common.cuh" +#include "contrib_ops/rocm/bert/elementwise.h" + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +namespace functor { + +struct FastGeLU { + template + __host__ __device__ __forceinline__ void operator()(T& y, const T& x) const { + constexpr const float b = 0.7978845608028654f; // sqrt(2.0/M_PI) + + // const T cdf = a + a * _Tanh(in * (c * in * in + b)); + const T xb = x * T(b); + const T u = xb * T(0.044715f) * x * x + xb; + const T emu = __expf(-u - u); + const T cdf = T(1.0f) / (T(1.0f) + emu); + y = x * cdf; + } +}; + +struct GeLU { + template + __host__ __device__ __forceinline__ void operator()(T& y, const T& x) const { + y = T(0.5f) * x * (T(1.f) + T(erf(0.70710678118f * float(x)))); + } +}; + +struct ReLU { + template + __host__ __device__ __forceinline__ void operator()(T& y, const T& x) const { + y = x >= T{} ? x : T{}; + } +}; + +} // namespace functor + +using onnxruntime::rocm::CeilDiv; +using onnxruntime::rocm::GPU_WARP_SIZE; + +template +__global__ void ElementwiseKernel( + const T* __restrict__ input, int input_length, + const T* __restrict__ bias, int bias_length, + T* __restrict__ output) { + const int idx = blockIdx.x * TPB + threadIdx.x; + Fn f{}; + + if (idx < input_length) { + const T x = input[idx] + (bias == nullptr ? 
T{} : bias[idx % bias_length]); + f(output[idx], x); + } +} + +template +__global__ void ElementwiseKernelVec( + const T* __restrict__ input, int input_length, + const T* __restrict__ bias, int bias_length, + T* output) { + using VecT = onnxruntime::rocm::aligned_vector; + Fn f{}; + + const int idx = (blockIdx.x * TPB + threadIdx.x) * ILP; + if (idx < input_length) { + T input_v[ILP]; + VecT* input_val = reinterpret_cast(&input_v); + *input_val = *reinterpret_cast(&input[idx]); + T output_v[ILP]; + VecT* output_val = reinterpret_cast(&output_v); + T bias_v[ILP]; + if (bias != nullptr) { + VecT* bias_val = reinterpret_cast(&bias_v); + *bias_val = *reinterpret_cast(&bias[idx % bias_length]); + } + +#pragma unroll + for (int i = 0; i < ILP; i++) { + const T x = (bias == nullptr) ? input_v[i] : (T)(input_v[i] + bias_v[i]); + f(output_v[i], x); + } + *(reinterpret_cast(&output[idx])) = *output_val; + } +} + +template +Status LaunchElementwiseKernel( + RocmTuningContext* tuning_ctx, Stream* stream, + const T* input, int input_length, + const T* bias, int bias_length, + T* output) { + internal::ElementwiseParams params(tuning_ctx, stream, input, bias, output, input_length, bias_length); + if (tuning_ctx->IsTunableOpEnabled()) { + static internal::ElementwiseTunableOp op; + return op(¶ms); + } + + return internal::ElementwiseStaticSelection(¶ms); +} + +namespace internal { + +template +Status ElementwiseOp::operator()(const ElementwiseParams* params) { + dim3 blocks(CeilDiv(params->input_length, ThreadsPerBlock * VecSize)); + ElementwiseKernelVec<<StreamHandle()>>>( + params->input, params->input_length, + params->bias, params->bias_length, + params->output); + return HIP_CALL(hipGetLastError()); +} + +template +Status ElementwiseOp::IsSupported(const ElementwiseParams* params) { + // TODO(anyone): Add tail handling for FastGelu + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + !((params->bias_length > 0 && params->bias_length % VecSize == 0 && params->input_length % VecSize == 0) || + (params->bias_length == 0 && params->input_length % VecSize == 0))); + // Avoid redundant configurations + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!(params->input_length > (ThreadsPerBlock - GPU_WARP_SIZE) * VecSize)); + + return Status::OK(); +} + +template +Status ElementwiseStaticSelection(const ElementwiseParams* params) { + constexpr int block_size = 256; + if constexpr (std::is_same_v) { + if (params->bias != nullptr) { + if (0 == (params->bias_length % 8) && (params->input_length >= 3145728)) { // 3145728=8*128*3072 + const int grid_size = (params->input_length / 8 + block_size - 1) / block_size; + ElementwiseKernelVec<<StreamHandle()>>>( + params->input, params->input_length, params->bias, params->bias_length, params->output); + } else if (0 == (params->bias_length % 4)) { + const int grid_size = (params->input_length / 4 + block_size - 1) / block_size; + ElementwiseKernelVec<<StreamHandle()>>>( + params->input, params->input_length, params->bias, params->bias_length, params->output); + } else if (0 == (params->bias_length % 2)) { + const int grid_size = (params->input_length / 2 + block_size - 1) / block_size; + ElementwiseKernelVec<<StreamHandle()>>>( + params->input, params->input_length, params->bias, params->bias_length, params->output); + } else { + const int grid_size = (params->input_length + block_size - 1) / block_size; + ElementwiseKernel<<StreamHandle()>>>( + params->input, params->input_length, params->bias, params->bias_length, params->output); + } + } else { + if (0 == (params->input_length % 
8) && (params->input_length >= 3145728)) { // 3145728=8*128*3072 + const int grid_size = (params->input_length / 8 + block_size - 1) / block_size; + ElementwiseKernelVec<<StreamHandle()>>>( + params->input, params->input_length, params->bias, params->bias_length, params->output); + } else if (0 == (params->input_length % 4)) { + const int grid_size = (params->input_length / 4 + block_size - 1) / block_size; + ElementwiseKernelVec<<StreamHandle()>>>( + params->input, params->input_length, params->bias, params->bias_length, params->output); + } else if (0 == (params->input_length % 2)) { + const int grid_size = (params->input_length / 2 + block_size - 1) / block_size; + ElementwiseKernelVec<<StreamHandle()>>>( + params->input, params->input_length, params->bias, params->bias_length, params->output); + } else { + const int grid_size = (params->input_length + block_size - 1) / block_size; + ElementwiseKernel<<StreamHandle()>>>( + params->input, params->input_length, params->bias, params->bias_length, params->output); + } + } + } else { + const int grid_size = (params->input_length + block_size - 1) / block_size; + ElementwiseKernel<<StreamHandle()>>>( + params->input, params->input_length, params->bias, params->bias_length, params->output); + } + return HIP_CALL(hipGetLastError()); +} + +template +ElementwiseTunableOp::ElementwiseTunableOp() { + this->RegisterOp(ElementwiseStaticSelection); + + this->RegisterOp(ElementwiseOp{}); + this->RegisterOp(ElementwiseOp{}); + this->RegisterOp(ElementwiseOp{}); + this->RegisterOp(ElementwiseOp{}); + this->RegisterOp(ElementwiseOp{}); + + this->RegisterOp(ElementwiseOp{}); + this->RegisterOp(ElementwiseOp{}); + this->RegisterOp(ElementwiseOp{}); + this->RegisterOp(ElementwiseOp{}); + this->RegisterOp(ElementwiseOp{}); + + this->RegisterOp(ElementwiseOp{}); + this->RegisterOp(ElementwiseOp{}); + this->RegisterOp(ElementwiseOp{}); + this->RegisterOp(ElementwiseOp{}); + this->RegisterOp(ElementwiseOp{}); + + this->RegisterOp(ElementwiseOp{}); + this->RegisterOp(ElementwiseOp{}); + this->RegisterOp(ElementwiseOp{}); + this->RegisterOp(ElementwiseOp{}); + this->RegisterOp(ElementwiseOp{}); + + this->RegisterOp(ElementwiseOp{}); + this->RegisterOp(ElementwiseOp{}); + this->RegisterOp(ElementwiseOp{}); + this->RegisterOp(ElementwiseOp{}); + this->RegisterOp(ElementwiseOp{}); + + this->RegisterOp(ElementwiseOp{}); + this->RegisterOp(ElementwiseOp{}); + this->RegisterOp(ElementwiseOp{}); + this->RegisterOp(ElementwiseOp{}); + this->RegisterOp(ElementwiseOp{}); + + this->RegisterOp(ElementwiseOp{}); + this->RegisterOp(ElementwiseOp{}); + this->RegisterOp(ElementwiseOp{}); + this->RegisterOp(ElementwiseOp{}); + this->RegisterOp(ElementwiseOp{}); + + this->RegisterOp(ElementwiseOp{}); + this->RegisterOp(ElementwiseOp{}); + this->RegisterOp(ElementwiseOp{}); + this->RegisterOp(ElementwiseOp{}); + this->RegisterOp(ElementwiseOp{}); +} + +#undef ADD_OP + +} // namespace internal + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime + +#define ELEMENTWISE_KERNEL_IMPL(Fn, T) \ + namespace onnxruntime { \ + namespace contrib { \ + namespace rocm { \ + template Status LaunchElementwiseKernel( \ + RocmTuningContext * tuning_ctx, Stream* stream, \ + const T* input, int input_length, \ + const T* bias, int bias_length, \ + T* output); \ + namespace internal { \ + template class ElementwiseTunableOp; \ + } \ + } \ + } \ + } diff --git a/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_fastgelu.cu 
b/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_fastgelu.cu new file mode 100644 index 0000000000000..c2a670ea76aca --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_fastgelu.cu @@ -0,0 +1,9 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/rocm/bert/elementwise_impl/impl.cuh" + +ELEMENTWISE_KERNEL_IMPL(functor::FastGeLU, float); +ELEMENTWISE_KERNEL_IMPL(functor::FastGeLU, double); +ELEMENTWISE_KERNEL_IMPL(functor::FastGeLU, half); +ELEMENTWISE_KERNEL_IMPL(functor::FastGeLU, BFloat16); diff --git a/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_gelu.cu b/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_gelu.cu new file mode 100644 index 0000000000000..97f0f74640c6e --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_gelu.cu @@ -0,0 +1,9 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/rocm/bert/elementwise_impl/impl.cuh" + +ELEMENTWISE_KERNEL_IMPL(functor::GeLU, double); +ELEMENTWISE_KERNEL_IMPL(functor::GeLU, float); +ELEMENTWISE_KERNEL_IMPL(functor::GeLU, half); +ELEMENTWISE_KERNEL_IMPL(functor::GeLU, BFloat16); diff --git a/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_relu.cu b/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_relu.cu new file mode 100644 index 0000000000000..67e50869133f5 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_relu.cu @@ -0,0 +1,8 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/rocm/bert/elementwise_impl/impl.cuh" + +ELEMENTWISE_KERNEL_IMPL(functor::ReLU, float); +ELEMENTWISE_KERNEL_IMPL(functor::ReLU, half); +ELEMENTWISE_KERNEL_IMPL(functor::ReLU, BFloat16); diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu.cc b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu.cc new file mode 100644 index 0000000000000..fdb62d3a2aec5 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu.cc @@ -0,0 +1,75 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
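Editor's note on the three elementwise_impl/*.cu files above: adding another activation only needs a device functor, its forward declarations, and one explicit-instantiation translation unit. A hypothetical sketch (the SiLU functor below is an illustration, not part of this patch):

// 1) Functor added next to FastGeLU/GeLU/ReLU in elementwise_impl/impl.cuh:
struct SiLU {
  template <typename T>
  __host__ __device__ __forceinline__ void operator()(T& y, const T& x) const {
    y = x / (T(1.0f) + T(expf(-float(x))));  // SiLU(x) = x * sigmoid(x)
  }
};

// 2) Forward declarations added to elementwise.h:
//      ELEMENTWISE_FWD_DECL(SiLU, float);
//      ELEMENTWISE_FWD_DECL(SiLU, half);

// 3) New translation unit elementwise_impl/impl_silu.cu:
//      #include "contrib_ops/rocm/bert/elementwise_impl/impl.cuh"
//      ELEMENTWISE_KERNEL_IMPL(functor::SiLU, float);
//      ELEMENTWISE_KERNEL_IMPL(functor::SiLU, half);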
+ +#include "contrib_ops/rocm/bert/gemm_fast_gelu.h" + +#include "contrib_ops/rocm/bert/gemm_fast_gelu_common.h" +#include "contrib_ops/rocm/bert/gemm_fast_gelu_impl.h" +#include "core/providers/cpu/math/matmul_helper.h" +#include "core/providers/rocm/rocm_common.h" + +using onnxruntime::rocm::ToHipType; + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + GemmFastGelu, \ + kMSDomain, \ + 1, \ + T, \ + kRocmExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + GemmFastGelu); + +REGISTER_KERNEL_TYPED(float) +REGISTER_KERNEL_TYPED(MLFloat16) +REGISTER_KERNEL_TYPED(BFloat16) + +template +Status GemmFastGelu::ComputeInternal(OpKernelContext* ctx) const { + typedef typename ToHipType::MappedType HipT; + + const auto* X = ctx->Input(0); + const auto* W = ctx->Input(1); + const auto* bias = ctx->Input(2); + + bool transa = false; + bool transb = false; + bool trans_batch_a = false; + bool trans_batch_b = false; + + MatMulComputeHelper helper; + ORT_RETURN_IF_ERROR(helper.Compute(X->Shape(), W->Shape(), transa, transb, trans_batch_a, trans_batch_b, false)); + + Tensor* Y = ctx->Output(0, helper.OutputShape()); + + // Bail out early if the output is going to be empty + if (Y->Shape().Size() == 0) + return Status::OK(); + + // gemmfastgelu only support alpha == 1 and beta == 0 + const HipT alpha = ToHipType::FromFloat(1.0f); + const HipT beta = ToHipType::FromFloat(0.0f); + + using onnxruntime::rocm::tunable::blas::BlasOp; + + return blas::row_major::GemmFastGelu( + GetTuningContext(), ctx->GetComputeStream(), GetHipblasHandle(ctx), + transa ? BlasOp::Trans : BlasOp::NonTrans, + transb ? BlasOp::Trans : BlasOp::NonTrans, + helper.M(), helper.N(), helper.K(), + alpha, + reinterpret_cast(X->Data()), helper.Lda(transa), + reinterpret_cast(W->Data()), helper.Ldb(transb), + (nullptr != bias) ? reinterpret_cast(bias->Data()) : nullptr, + beta, + reinterpret_cast(Y->MutableData()), helper.Ldc()); +} + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu.h b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu.h new file mode 100644 index 0000000000000..ae4f84fa5f033 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu.h @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/common/common.h" +#include "core/providers/rocm/rocm_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +using onnxruntime::rocm::RocmKernel; + +template +class GemmFastGelu final : public RocmKernel { + public: + GemmFastGelu(const OpKernelInfo& op_kernel_info) : RocmKernel(op_kernel_info) {} + Status ComputeInternal(OpKernelContext* ctx) const override; +}; + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_ck.cuh b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_ck.cuh new file mode 100644 index 0000000000000..77f53f9eed027 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_ck.cuh @@ -0,0 +1,133 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include +#include +#include + +#ifdef USE_COMPOSABLE_KERNEL +#include "core/providers/rocm/composable_kernel_common.h" + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/gemm_add_fastgelu.hpp" +#include "ck/library/tensor_operation_instance/gpu/gemm_fastgelu.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#endif + +#include "contrib_ops/rocm/bert/gemm_fast_gelu_common.h" + +using onnxruntime::rocm::ToHipType; + +namespace onnxruntime { +namespace contrib { +namespace rocm { +namespace blas { +namespace internal { + +#ifdef USE_COMPOSABLE_KERNEL + +using onnxruntime::rocm::CKBlasOpAdaptor; +using onnxruntime::rocm::CKDataTypeAdaptor; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using Nop = ck::tensor_operation::element_wise::PassThrough; +using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu; +using FastGelu = ck::tensor_operation::element_wise::FastGelu; + +template +auto GetCKGemmAddFastGeluTypeStringAndOps() { + using CKDataType = typename CKDataTypeAdaptor::type; + using ALayout = typename CKBlasOpAdaptor::type; + using BLayout = typename CKBlasOpAdaptor::type; + using DeviceGemmAddFastGelu = ck::tensor_operation::device::DeviceGemmMultipleD< + ALayout, BLayout, ck::Tuple, Row, + CKDataType, CKDataType, ck::Tuple, CKDataType, + Nop, Nop, AddFastGelu>; + using InstanceFactory = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory; + + std::vector>>> ret; + for (auto&& impl : InstanceFactory::GetInstances()) { + auto type_string = onnxruntime::MakeString("withbias ", impl->GetTypeString()); + + auto invoker = impl->MakeInvokerPointer(); + auto ck_gemmfastgelu_op = [impl = std::move(impl), invoker = std::move(invoker)](const GemmFastGeluParams* params) -> Status { + auto one = ToHipType::FromFloat(1.0f); + auto zero = ToHipType::FromFloat(0.0f); + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + params->alpha != one || params->beta != zero || params->bias == nullptr, + impl->GetTypeString(), " only supports alpha == 1 and beta == 0 and bias != nullptr"); + + auto nop = Nop{}; + auto addfastgelu = AddFastGelu{}; + auto arg = impl->MakeArgumentPointer(params->a, params->b, std::array{params->bias}, params->c, + params->m, params->n, params->k, + params->lda, params->ldb, std::array{0}, params->ldc, + nop, nop, addfastgelu); + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), + impl->GetTypeString(), " does not support the params"); + invoker->Run(arg.get(), StreamConfig{params->StreamHandle()}); + return Status::OK(); + }; + ret.emplace_back(std::make_pair(std::move(type_string), std::move(ck_gemmfastgelu_op))); + } + return ret; +} + +template +auto GetCKGemmFastGeluTypeStringAndOps() { + using CKDataType = typename CKDataTypeAdaptor::type; + using ALayout = typename CKBlasOpAdaptor::type; + using BLayout = typename CKBlasOpAdaptor::type; + using DeviceGemmFastGelu = ck::tensor_operation::device::DeviceGemmMultipleD< + ALayout, BLayout, ck::Tuple<>, Row, + CKDataType, CKDataType, ck::Tuple<>, CKDataType, + Nop, Nop, FastGelu>; + using InstanceFactory = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory; + + std::vector>>> ret; + for (auto&& impl : InstanceFactory::GetInstances()) { + auto type_string = onnxruntime::MakeString("nobias ", 
impl->GetTypeString()); + auto invoker = impl->MakeInvokerPointer(); + auto ck_gemmfastgelu_op = [impl = std::move(impl), invoker = std::move(invoker)](const GemmFastGeluParams* params) -> Status { + auto one = ToHipType::FromFloat(1.0f); + auto zero = ToHipType::FromFloat(0.0f); + + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + params->alpha != one || params->beta != zero || params->bias != nullptr, + impl->GetTypeString(), " only supports alpha == 1 and beta == 0 and bias == nullptr"); + + auto nop = Nop{}; + auto fastgelu = FastGelu{}; + auto arg = impl->MakeArgumentPointer(params->a, params->b, + {}, + params->c, + params->m, params->n, params->k, + params->lda, params->ldb, + {}, + params->ldc, + nop, nop, fastgelu); + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), + impl->GetTypeString(), " does not support the params"); + invoker->Run(arg.get(), StreamConfig{params->StreamHandle()}); + return Status::OK(); + }; + ret.emplace_back(std::make_pair(std::move(type_string), std::move(ck_gemmfastgelu_op))); + } + return ret; +} +#else +struct Row {}; +struct Col {}; +#endif // USE_COMPOSABLE_KERNEL + +} // namespace internal +} // namespace blas +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_common.h b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_common.h new file mode 100644 index 0000000000000..2b8a21b83f177 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_common.h @@ -0,0 +1,47 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include + +#include "core/providers/rocm/rocm_common.h" +#include "core/providers/rocm/tunable/gemm_common.h" +#include "core/providers/rocm/tunable/rocm_tunable.h" + +using onnxruntime::rocm::tunable::blas::BlasOp; +using onnxruntime::rocm::tunable::blas::BlasOpToString; + +namespace onnxruntime { +namespace contrib { +namespace rocm { +namespace blas { + +template +struct GemmFastGeluParams : OpParams { + std::string Signature() const override { + bool has_bias = (nullptr != bias) ? 0 : 1; + return MakeString(BlasOpToString(opa), BlasOpToString(opb), "_", m, "_", n, "_", k, '_', has_bias); + } + hipblasHandle_t handle; + BlasOp opa; + BlasOp opb; + int64_t m; + int64_t n; + int64_t k; + T alpha; + const T* a; + int64_t lda; + const T* b; + int64_t ldb; + const T* bias; + T beta; + T* c; + int64_t ldc; +}; + +} // namespace blas +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.cu b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.cu new file mode 100644 index 0000000000000..8d7e64b1015be --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.cu @@ -0,0 +1,91 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
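The composable-kernel glue above hands back (type_string, callable) pairs, where each callable first rejects unsupported arguments and then launches one CK instance; GemmFastGeluParams::Signature() then keys the tuning result per problem shape. A toy, host-only model of that select-the-fastest pattern (names illustrative; not the real TunableOp machinery):

#include <chrono>
#include <cstdint>
#include <functional>
#include <limits>
#include <string>
#include <utility>
#include <vector>

// Stand-in for the real GemmFastGeluParams<T>; fields trimmed for illustration.
struct ToyParams { std::int64_t m, n, k; };

// A candidate returns false when it cannot handle the arguments, mirroring
// TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF in the lambdas above.
using Candidate = std::function<bool(const ToyParams&)>;

// Time every supported candidate once and remember the fastest one.
// (Returns index 0 if nothing was supported; the real op would report an error.)
std::size_t PickFastest(const std::vector<std::pair<std::string, Candidate>>& candidates,
                        const ToyParams& params) {
  std::size_t best = 0;
  double best_us = std::numeric_limits<double>::infinity();
  for (std::size_t i = 0; i < candidates.size(); ++i) {
    const auto start = std::chrono::steady_clock::now();
    if (!candidates[i].second(params)) continue;  // unsupported argument, skip
    const std::chrono::duration<double, std::micro> took = std::chrono::steady_clock::now() - start;
    if (took.count() < best_us) { best_us = took.count(); best = i; }
  }
  return best;
}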
+ +#define _GEMM_FASTGELU_H_KEEP_SIGNATURE_DEFINES +#include "contrib_ops/rocm/bert/gemm_fast_gelu_impl.h" + +#include +#include + +#include "contrib_ops/rocm/bert/gemm_fast_gelu_tunable.cuh" +#include "core/providers/rocm/shared_inc/fpgeneric.h" + +namespace onnxruntime { +namespace contrib { +namespace rocm { +namespace blas { + +namespace row_major { + +template +inline GEMMFASTGELU(T, ScalarT) { + GemmFastGeluParams params; + params.tuning_ctx = tuning_ctx; + params.stream = stream; + params.handle = handle; + + params.opa = opa; + params.opb = opb; + params.m = m; + params.n = n; + params.k = k; + if constexpr (!std::is_same_v && std::is_same_v) { + params.alpha = ToHipType::FromFloat(std::forward(alpha)); + } else { + params.alpha = alpha; + } + params.a = a; + params.lda = lda; + params.b = b; + params.ldb = ldb; + params.bias = bias; + if constexpr (!std::is_same_v && std::is_same_v) { + params.beta = ToHipType::FromFloat(std::forward(beta)); + } else { + params.beta = beta; + } + params.c = c; + params.ldc = ldc; + + if (tuning_ctx->IsTunableOpEnabled()) { + if (opa == BlasOp::N && opb == BlasOp::N) { + static internal::GemmFastGeluTunableOp gemm_fast_gelu{}; + return gemm_fast_gelu(¶ms); + } else if (opa == BlasOp::T && opb == BlasOp::N) { + static internal::GemmFastGeluTunableOp gemm_fast_gelu{}; + return gemm_fast_gelu(¶ms); + } else if (opa == BlasOp::N && opb == BlasOp::T) { + static internal::GemmFastGeluTunableOp gemm_fast_gelu{}; + return gemm_fast_gelu(¶ms); + } else /*if (opa == BlasOp::T && opb == BlasOp::T)*/ { + static internal::GemmFastGeluTunableOp gemm_fast_gelu{}; + return gemm_fast_gelu(¶ms); + } + } + + return internal::GemmFastGeluUnfused(¶ms); +} + +#define CALL_GEMMFASTGELU(T, ScalarT) \ + GemmFastGelu(tuning_ctx, stream, handle, \ + opa, opb, \ + m, n, k, \ + alpha, a, lda, b, ldb, bias, \ + beta, c, ldc) + +// clang-format off +GEMMFASTGELU(float, float ) { return CALL_GEMMFASTGELU(float, float ); } +GEMMFASTGELU(half, half ) { return CALL_GEMMFASTGELU(half, half ); } +GEMMFASTGELU(BFloat16, BFloat16) { return CALL_GEMMFASTGELU(BFloat16, BFloat16); } +GEMMFASTGELU(half, float ) { return CALL_GEMMFASTGELU(half, float ); } +GEMMFASTGELU(BFloat16, float ) { return CALL_GEMMFASTGELU(BFloat16, float ); } +// clang-format on + +#undef CALL_GEMMFASTGELU + +} // namespace row_major + +} // namespace blas +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.h b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.h new file mode 100644 index 0000000000000..b707c63ef44be --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.h @@ -0,0 +1,40 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
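The clang-format-off block above stamps out the five exported overloads via CALL_GEMMFASTGELU; after preprocessing, each is assumed to become a plain forwarding function along these lines (the explicit <half, float> template arguments are an inference from the macro's intent, since the rendered diff does not show them):

// Approximate post-preprocessing form of GEMMFASTGELU(half, float) { return CALL_GEMMFASTGELU(half, float); }
common::Status GemmFastGelu(
    RocmTuningContext* tuning_ctx, Stream* stream, hipblasHandle_t handle,
    BlasOp opa, BlasOp opb,
    std::int64_t m, std::int64_t n, std::int64_t k,
    float alpha, const half* a, std::int64_t lda, const half* b, std::int64_t ldb,
    const half* bias, float beta, half* c, std::int64_t ldc) {
  // Forwards to the templated implementation defined earlier in this file.
  return GemmFastGelu<half, float>(tuning_ctx, stream, handle,
                                   opa, opb,
                                   m, n, k,
                                   alpha, a, lda, b, ldb, bias,
                                   beta, c, ldc);
}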
+ +#pragma once + +#include "contrib_ops/rocm/bert/gemm_fast_gelu_common.h" +#include "core/common/status.h" +#include "core/common/float16.h" + +namespace onnxruntime { +namespace contrib { +namespace rocm { +namespace blas { + +#define GEMMFASTGELU(T, ScalarT) \ + common::Status GemmFastGelu( \ + RocmTuningContext* tuning_ctx, Stream* stream, hipblasHandle_t handle, \ + BlasOp opa, BlasOp opb, \ + std::int64_t m, std::int64_t n, std::int64_t k, \ + ScalarT alpha, const T* a, std::int64_t lda, const T* b, std::int64_t ldb, \ + const T* bias, ScalarT beta, T* c, std::int64_t ldc) + +namespace row_major { + +GEMMFASTGELU(float, float); +GEMMFASTGELU(half, half); +GEMMFASTGELU(BFloat16, BFloat16); +GEMMFASTGELU(half, float); +GEMMFASTGELU(BFloat16, float); + +} // namespace row_major + +} // namespace blas +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime + +#ifndef _GEMM_FASTGELU_H_KEEP_SIGNATURE_DEFINES +#undef GEMMFASTGELU +#endif diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_tunable.cuh b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_tunable.cuh new file mode 100644 index 0000000000000..e157aa57f8c43 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_tunable.cuh @@ -0,0 +1,83 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include + +#include "contrib_ops/rocm/bert/elementwise.h" +#include "contrib_ops/rocm/bert/gemm_fast_gelu_ck.cuh" +#include "contrib_ops/rocm/bert/gemm_fast_gelu_common.h" +#include "core/providers/rocm/tunable/gemm.h" +#include "core/providers/rocm/tunable/gemm_hipblaslt.h" +#include "core/providers/rocm/tunable/rocm_tunable.h" + +namespace onnxruntime { +namespace contrib { +namespace rocm { +namespace blas { +namespace internal { + +using namespace onnxruntime::rocm::tunable::blas::internal; + +template +Status GemmFastGeluUnfused(const GemmFastGeluParams* params) { + namespace column_major = onnxruntime::rocm::tunable::blas::column_major; + ORT_RETURN_IF_ERROR(column_major::Gemm(params->tuning_ctx, params->stream, params->handle, + params->opb, params->opa, + params->n, params->m, params->k, + params->alpha, params->b, params->ldb, params->a, params->lda, + params->beta, params->c, params->ldc)); + + int64_t fast_gelu_input_length = params->m * params->n; + int64_t bias_length = (params->bias != nullptr) ? params->n : 0; + + // Because of GemmFastGeluUnfused is a combination of GemmOp and FastGeluOp, FastGeluOp in this combination is + // an inplace computation. + // 1. If we call GemmFastGeluUnfused directly with enabled tuning, it may cause the input buffer of FastGelu been + // updated accumulatedly and result in incorrect result finally. This only happens if the tuning's FindFastest is invoked. + // 2. It's safe to call GemmFastGeluUnfused with disabled tuning, FastGelu only run once and produce correct result. + // 3. It's safe to call GemmFastGeluUnfused as part of GemmFastGeluTunableOp with enable tuning, GemmTunableOp and + // FastGeluTunableOp will do tune in first warmup step separately during GemmFastGeluUnfused profiling process. + // After that, the call to GemmFastGeluUnfused not invoke tuning's FindFastest of FastGelu. + // + // Note: If any change cause directly usage of GemmFastGeluUnfused, add PreTuning() and PostTuning() in FastGeluTunableOp + // to protect original input value. 
+ return onnxruntime::contrib::rocm::LaunchElementwiseKernel( + params->tuning_ctx, params->Stream(), + params->c, static_cast(fast_gelu_input_length), + params->bias, static_cast(bias_length), + params->c); +} + +template +class GemmFastGeluTunableOp : public TunableOp> { + public: + GemmFastGeluTunableOp() { + this->RegisterOp(GemmFastGeluUnfused); +#ifdef USE_COMPOSABLE_KERNEL + for (auto&& [_, op] : GetCKGemmAddFastGeluTypeStringAndOps()) { + ORT_UNUSED_PARAMETER(_); + this->RegisterOp(std::move(op)); + } + for (auto&& [_, op] : GetCKGemmFastGeluTypeStringAndOps()) { + ORT_UNUSED_PARAMETER(_); + this->RegisterOp(std::move(op)); + } +#endif + +#ifdef USE_HIPBLASLT + for (auto&& [_, op] : GetHipBlasLtGemmFastGeluTypeStringAndOps()) { + ORT_UNUSED_PARAMETER(_); + this->RegisterOp(std::move(op)); + } +#endif + } +}; + +} // namespace internal +} // namespace blas +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu new file mode 100644 index 0000000000000..09a6550549614 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu @@ -0,0 +1,530 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/shared_library/provider_api.h" +#include "core/providers/rocm/rocm_common.h" +#include "core/platform/env_var_utils.h" +#include "contrib_ops/rocm/bert/group_query_attention.h" +#include "contrib_ops/cpu/bert/group_query_attention_helper.h" +#include "contrib_ops/rocm/bert/rotary_embedding_impl.h" +#include "contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh" + +#ifdef USE_COMPOSABLE_KERNEL_CK_TILE +#include "ck_tile/core/numeric/integer.hpp" +#include "fmha_fwd.hpp" +#endif + +using namespace onnxruntime::rocm; +using namespace ::onnxruntime::common; +using namespace ONNX_NAMESPACE; + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + GroupQueryAttention, \ + kMSDomain, \ + 1, \ + T, \ + kRocmExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("M", DataTypeImpl::GetTensorType()) \ + .MayInplace(3, 1) \ + .MayInplace(4, 2) \ + .InputMemoryType(OrtMemTypeCPUInput, 6), \ + GroupQueryAttention); + +// REGISTER_KERNEL_TYPED(float) +REGISTER_KERNEL_TYPED(MLFloat16) +// REGISTER_KERNEL_TYPED(BFloat16) + +template +std::string GetCkFmhaDataTypeString(); + +template <> +std::string GetCkFmhaDataTypeString() { + return "fp16"; +} + +template <> +std::string GetCkFmhaDataTypeString() { + return "bf16"; +} + +__global__ void seqlens_inc_kernel(const int* seqlens, int* out, int num_elems, int inc) { + int idx = blockDim.x * blockIdx.x + threadIdx.x; + if (idx < num_elems) { + out[idx] = seqlens[idx] + inc; + } +} + +Status LaunchSeqlensInc(hipStream_t stream, const int* seqlens, int* out, int num_elems, int inc) { + constexpr int NumThreads = 128; + int num_blks = CeilDiv(num_elems, NumThreads); + seqlens_inc_kernel<<>>(seqlens, out, num_elems, inc); + return HIP_CALL(hipGetLastError()); +} + +__global__ void seqstart_init_kernel(int* out, int num_elems, int length_per_seq) { + int idx = blockDim.x * blockIdx.x + threadIdx.x; + if (idx < num_elems) { + out[idx] = idx * length_per_seq; + } + if (idx == 0) { + out[num_elems] = num_elems * length_per_seq; + } +} + +Status 
LaunchSeqStartInit(hipStream_t stream, int* out, int num_elems, int length_per_seq) { + constexpr int NumThreads = 128; + int num_blks = CeilDiv(num_elems, NumThreads); + seqstart_init_kernel<<>>(out, num_elems, length_per_seq); + return HIP_CALL(hipGetLastError()); +} + +// Kernel to convert seqlens_k to position_ids +__global__ void SeqlensToPosIdsPrompt(const int32_t* seqlens_k, int64_t* position_ids, const int seqlen, + const int batch_size) { + int tid = blockDim.x * blockIdx.x + threadIdx.x; + int b = tid / seqlen; + int s = tid % seqlen; + if (b < batch_size) { + if (s < seqlens_k[b] + 1) { + position_ids[tid] = s; + } else { + position_ids[tid] = 1; + } + } +} + +// Kernel to convert seqlens_k to position_ids +__global__ void SeqlensToPosIdsToken(const int32_t* seqlens_k, int64_t* position_ids, const int batch_size) { + int tid = blockDim.x * blockIdx.x + threadIdx.x; + if (tid < batch_size) { + position_ids[tid] = seqlens_k[tid]; + } +} + +// Convert seqlens_k to position_ids +Status LaunchSeqlensToPosIds(contrib::GroupQueryAttentionParameters& parameters, const int32_t* seqlens_k, + int64_t* position_ids, hipStream_t stream, const int max_threads_per_block) { + const int seqlen = parameters.sequence_length; + const int batch_size = parameters.batch_size; + const int threads = max_threads_per_block; + const int blocks = (batch_size * seqlen + threads - 1) / threads; + if (parameters.is_first_prompt) { + SeqlensToPosIdsPrompt<<>>(seqlens_k, position_ids, seqlen, batch_size); + } else { + SeqlensToPosIdsToken<<>>(seqlens_k, position_ids, batch_size); + } + return HIP_CALL(hipGetLastError()); +} + +template +GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) + : RocmKernel(info) { + int64_t num_heads = 0; + int64_t kv_num_heads = 0; + ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0); + ORT_ENFORCE(info.GetAttr("kv_num_heads", &kv_num_heads).IsOK() && kv_num_heads > 0 && num_heads % kv_num_heads == 0); + num_heads_ = static_cast(num_heads); + kv_num_heads_ = static_cast(kv_num_heads); + is_past_bsnh_ = false; + is_unidirectional_ = true; + local_window_size_ = static_cast(info.GetAttrOrDefault("local_window_size", -1)); + do_rotary_ = info.GetAttrOrDefault("do_rotary", 0) == 1; + rotary_interleaved_ = info.GetAttrOrDefault("rotary_interleaved", 0) == 1; + scale_ = info.GetAttrOrDefault("scale", 0.0f); +} + +template <> +std::once_flag GroupQueryAttention::arch_checking_{}; + +template <> +std::once_flag GroupQueryAttention::arch_checking_{}; + +template +Status GroupQueryAttention::ComputeInternal(OpKernelContext* ctx) const { +#if USE_COMPOSABLE_KERNEL_CK_TILE + auto hip_stream = static_cast(ctx->GetComputeStream()->GetHandle()); + const Tensor* query = ctx->Input(0); + const Tensor* key = ctx->Input(1); + const Tensor* value = ctx->Input(2); + const Tensor* past_key = ctx->Input(3); + const Tensor* past_value = ctx->Input(4); + const Tensor* seqlens_k = ctx->Input(5); + const Tensor* total_seqlen = ctx->Input(6); + const Tensor* cos_cache = ctx->Input(7); + const Tensor* sin_cache = ctx->Input(8); + + auto& device_prop = GetDeviceProp(); + std::call_once( + arch_checking_, + [](const hipDeviceProp_t& device_prop) { + if (std::string_view(device_prop.gcnArchName).find("gfx90a") == std::string_view::npos && + std::string_view(device_prop.gcnArchName).find("gfx942") == std::string_view::npos) { + LOGS_DEFAULT(WARNING) + << "GroupQueryAttention currently only supports ck_tile fmha backend which only supports " + << "CDNA2 and CDNA3 archs."; + 
LOGS_DEFAULT(WARNING) + << "GroupQueryAttention running on an unsuppoted GPU may result in " + << "hipErrorNoBinaryForGpu or hipErrorSharedObjectInitFailedshared error."; + } + }, + device_prop); + + GroupQueryAttentionParameters parameters; + using HipT = typename ToHipType::MappedType; + + const int max_thr_per_blk = device_prop.maxThreadsPerBlock; + + ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckInputs(query, + key, + value, + past_key, + past_value, + cos_cache, + sin_cache, + ¶meters, + num_heads_, + kv_num_heads_, + seqlens_k, + total_seqlen, + is_past_bsnh_, + scale_, + max_thr_per_blk)); + + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int kv_sequence_length = parameters.sequence_length; + const int num_heads = parameters.num_heads; + const int kv_num_heads = parameters.kv_num_heads; + const int head_size = parameters.head_size; + AttentionQkvFormat past_kv_format = parameters.past_kv_format; + + parameters.local_window_size = local_window_size_; + parameters.is_unidirectional = is_unidirectional_; + // parameters.zeros_count = kZerosCount; + // parameters.zero_ptr = zeros_.get(); + // parameters.left_padding = left_padding_; + parameters.do_rotary = do_rotary_; + parameters.rotary_interleaved = rotary_interleaved_; + + ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckNoQKOutput( + context->OutputCount(), + static_cast(Info().GetAttrOrDefault("qk_output", static_cast(QKOutputType::NO_OUTPUT))))); + + if (do_rotary_ && (cos_cache == nullptr || sin_cache == nullptr)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "cos_cache and sin_cache must be passed to GroupQueryAttention when do_rotary = 1"); + } + + TensorShapeVector output_shape(3); + output_shape[0] = static_cast(batch_size); + output_shape[1] = static_cast(sequence_length); + output_shape[2] = static_cast(parameters.hidden_size); + Tensor* output = ctx->Output(0, output_shape); + Strides output_strides = Strides::BSNHMemory(batch_size, sequence_length, num_heads, head_size); + + int4 past_shape; + std::vector present_dims; + Strides present_strides; + Strides past_strides; + if (past_kv_format == AttentionQkvFormat::Q_K_V_BSNH) { + past_shape = { + batch_size, parameters.seqlen_past_kv_cache, kv_num_heads, head_size}; + past_strides = Strides::BSNHMemory( + batch_size, parameters.seqlen_past_kv_cache, kv_num_heads, head_size); + present_dims = { + batch_size, parameters.seqlen_present_kv_cache, kv_num_heads, head_size}; + present_strides = Strides::BSNHMemory( + batch_size, parameters.seqlen_present_kv_cache, kv_num_heads, head_size); + } else { // BNSH + past_shape = { + batch_size, kv_num_heads, parameters.seqlen_past_kv_cache, head_size}; + past_strides = Strides::BNSHMemory( + batch_size, kv_num_heads, parameters.seqlen_past_kv_cache, head_size); + present_dims = { + batch_size, kv_num_heads, parameters.seqlen_present_kv_cache, head_size}; + present_strides = Strides::BNSHMemory( + batch_size, kv_num_heads, parameters.seqlen_present_kv_cache, head_size); + } + TensorShape present_shape(present_dims); + Tensor* present_key = ctx->Output(1, present_shape); + Tensor* present_value = ctx->Output(2, present_shape); + + Strides query_strides; + Strides key_strides; + Strides value_strides; + int4 kv_shape{batch_size, kv_num_heads, kv_sequence_length, head_size}; // BNSH coord + const HipT* query_ptr = reinterpret_cast(query->DataRaw()); + const HipT* key_ptr; + const HipT* value_ptr; + if (!parameters.is_packed_qkv) { + query_strides 
= Strides::BSNHMemory(batch_size, sequence_length, num_heads, head_size); + key_strides = Strides::BSNHMemory(batch_size, kv_sequence_length, kv_num_heads, head_size); + value_strides = key_strides; + key_ptr = reinterpret_cast(key->DataRaw()); + value_ptr = reinterpret_cast(value->DataRaw()); + } else { + query_strides = Strides::BSNHMemory(batch_size, sequence_length, num_heads + 2 * kv_num_heads, head_size); + key_strides = Strides::BSNHMemory(batch_size, sequence_length, num_heads + 2 * kv_num_heads, head_size); + value_strides = query_strides; + const size_t key_offset = static_cast(num_heads * head_size); + const size_t value_offset = static_cast(kv_num_heads * head_size); + key_ptr = query_ptr + key_offset; + value_ptr = key_ptr + value_offset; + } + + IAllocatorUniquePtr rotary_q_tmp; + IAllocatorUniquePtr rotary_k_tmp; + if (parameters.do_rotary) { + size_t q_size = static_cast(batch_size * sequence_length * num_heads * head_size); + size_t k_size = static_cast(batch_size * sequence_length * kv_num_heads * head_size); + auto rotary_q_strides = Strides::BSNHMemory(batch_size, sequence_length, num_heads, head_size); + auto rotary_k_strides = Strides::BSNHMemory(batch_size, sequence_length, kv_num_heads, head_size); + + rotary_q_tmp = GetScratchBuffer(q_size, ctx->GetComputeStream()); + rotary_k_tmp = GetScratchBuffer(k_size, ctx->GetComputeStream()); + auto rotary_position_ids_tmp = GetScratchBuffer(sequence_length * batch_size, ctx->GetComputeStream()); + ORT_RETURN_IF_ERROR(LaunchSeqlensToPosIds(parameters, + reinterpret_cast(seqlens_k->DataRaw()), + reinterpret_cast(rotary_position_ids_tmp.get()), + hip_stream, max_thr_per_blk)); + // Launch rotary embedding kernel + ORT_RETURN_IF_ERROR(LaunchRotaryEmbeddingKernel(hip_stream, rotary_q_tmp.get(), query_ptr, + reinterpret_cast(rotary_position_ids_tmp.get()), + reinterpret_cast(cos_cache->DataRaw()), + reinterpret_cast(sin_cache->DataRaw()), + parameters.batch_size, parameters.sequence_length, + parameters.num_heads, parameters.head_size, + parameters.rotary_dim, parameters.seqlen_present_kv_cache, + /*position_ids_format*/ 1, parameters.rotary_interleaved, + max_thr_per_blk, + query_strides.ForBNSHCoord(), + rotary_q_strides.ForBNSHCoord())); + ORT_RETURN_IF_ERROR(LaunchRotaryEmbeddingKernel(hip_stream, rotary_k_tmp.get(), key_ptr, + reinterpret_cast(rotary_position_ids_tmp.get()), + reinterpret_cast(cos_cache->DataRaw()), + reinterpret_cast(sin_cache->DataRaw()), + parameters.batch_size, parameters.sequence_length, + parameters.kv_num_heads, parameters.head_size, + parameters.rotary_dim, parameters.seqlen_present_kv_cache, + /*position_ids_format*/ 1, parameters.rotary_interleaved, + max_thr_per_blk, + key_strides.ForBNSHCoord(), + rotary_k_strides.ForBNSHCoord())); + query_ptr = reinterpret_cast(rotary_q_tmp.get()); + key_ptr = reinterpret_cast(rotary_k_tmp.get()); + query_strides = rotary_q_strides; + key_strides = rotary_k_strides; + } + + const int* seqlens_k_ptr = seqlens_k ? 
reinterpret_cast(seqlens_k->DataRaw()) : nullptr; + IAllocatorUniquePtr seqlens_k_tmp; + + // build present kv cache + auto* present_key_ptr = reinterpret_cast(present_key->MutableDataRaw()); + auto* present_value_ptr = reinterpret_cast(present_value->MutableDataRaw()); + if (parameters.is_first_prompt) { + // copy prompt kv to present kv + ORT_RETURN_IF_ERROR(LaunchStridedCopy(hip_stream, key_ptr, kv_shape, key_strides.ForBNSHCoord(), + present_key_ptr, present_strides.ForBNSHCoord(), max_thr_per_blk)); + ORT_RETURN_IF_ERROR(LaunchStridedCopy(hip_stream, value_ptr, kv_shape, value_strides.ForBNSHCoord(), + present_value_ptr, present_strides.ForBNSHCoord(), max_thr_per_blk)); + } else { + const auto* past_key_ptr = past_key == nullptr ? nullptr : reinterpret_cast(past_key->DataRaw()); + const auto* past_value_ptr = past_key == nullptr ? nullptr : reinterpret_cast(past_value->DataRaw()); + parameters.kv_share_buffer = past_key_ptr == present_key_ptr; // FIXME: + if (!parameters.kv_share_buffer) { + // copy past to present, + // NOTE: we do a low perf full buffer copy due to the seqlens_k indicate the seqlen of different seqs are + // not the same, aka, can not be as simple as strided + ORT_RETURN_IF_ERROR(LaunchStridedCopy(hip_stream, past_key_ptr, past_shape, past_strides.ForBNSHCoord(), + present_key_ptr, present_strides.ForBNSHCoord(), max_thr_per_blk)); + ORT_RETURN_IF_ERROR(LaunchStridedCopy(hip_stream, past_value_ptr, past_shape, past_strides.ForBNSHCoord(), + present_value_ptr, present_strides.ForBNSHCoord(), max_thr_per_blk)); + } else { + // In the case of share buffer + ORT_ENFORCE(past_key_ptr == nullptr || past_key_ptr == present_key_ptr); + ORT_ENFORCE(past_key_ptr == nullptr || past_value_ptr == present_value_ptr); + } + // then append new kv to present + size_t buffer_offset = seqlens_k ? 0 : present_strides.OffsetAt(0, 0, kv_sequence_length, 0); + ORT_RETURN_IF_ERROR(LaunchStridedCopy( + hip_stream, key_ptr, kv_shape, key_strides.ForBNSHCoord(), /*in_seqlens_offset=*/nullptr, + present_key_ptr + buffer_offset, present_strides.ForBNSHCoord(), seqlens_k_ptr, + max_thr_per_blk)); + ORT_RETURN_IF_ERROR(LaunchStridedCopy( + hip_stream, value_ptr, kv_shape, value_strides.ForBNSHCoord(), /*in_seqlens_offset=*/nullptr, + present_value_ptr + buffer_offset, present_strides.ForBNSHCoord(), seqlens_k_ptr, + max_thr_per_blk)); + + // NOTE: ORT: seqlens_k Indicates past sequence lengths for token generation case. + // we should call fmha with total sequence lengths + seqlens_k_tmp = GetScratchBuffer(batch_size * sizeof(int), ctx->GetComputeStream()); + ORT_RETURN_IF_ERROR(LaunchSeqlensInc(hip_stream, seqlens_k_ptr, seqlens_k_tmp.get(), batch_size, sequence_length)); + seqlens_k_ptr = seqlens_k_tmp.get(); + } + static_assert(std::is_same_v); + + const float scale = parameters.scale == 0.0f + ? 1.f / sqrt(static_cast(parameters.head_size)) + : parameters.scale; + bias_enum bias_type = bias_enum::no_bias; + + mask_info mask = [&]() { + if (local_window_size_ != -1) { + mask_info ret; + ret.type = mask_enum::window_generic; + ret.left = local_window_size_; + ret.right = parameters.is_unidirectional ? 
0 : -1; + // ret.x = kv_sequence_length - (sequence_length - ret.left); + // ret.y = sequence_length + (ret.right - kv_sequence_length); + return ret; + } + + if (parameters.is_first_prompt && is_unidirectional_) { + return mask_info::decode("t", sequence_length, kv_sequence_length); + } + + return mask_info::decode("0", sequence_length, kv_sequence_length); + }(); + + auto seqstart_q_tmp = GetScratchBuffer((batch_size + 1) * sizeof(int), ctx->GetComputeStream()); + auto seqstart_k_tmp = GetScratchBuffer((batch_size + 1) * sizeof(int), ctx->GetComputeStream()); + ORT_RETURN_IF_ERROR(LaunchSeqStartInit( + hip_stream, seqstart_q_tmp.get(), batch_size, + query_strides.strides_for_bnsh_coord.x / query_strides.strides_for_bnsh_coord.z)); + ORT_RETURN_IF_ERROR(LaunchSeqStartInit( + hip_stream, seqstart_k_tmp.get(), batch_size, + present_strides.strides_for_bnsh_coord.x / present_strides.strides_for_bnsh_coord.z)); + + fmha_fwd_args args{ + query_ptr, + present_key->DataRaw(), + present_value->DataRaw(), + nullptr, // bias, alibi/element + nullptr, // lse, logsumexp buffer + output->MutableDataRaw(), + seqstart_q_tmp.get(), // seqstart_q_ptr, for group mode + seqstart_k_tmp.get(), // seqstart_k_ptr, for group mode + seqlens_k_ptr, // seqlen_k_ptr, for group mode + sequence_length, // seqlen_q, for batch mode + kv_sequence_length, // seqlen_k, for batch mode + parameters.batch_size, // batch + parameters.sequence_length, // max_seqlen_q + parameters.head_size, // hdim_q + parameters.head_size, // hdim_v + parameters.num_heads, + parameters.kv_num_heads, + scale, + 1.0f, // scale_p of squant, useless + 1.0f, // scale_o of squant, useless + static_cast(query_strides.strides_for_bnsh_coord.z), // stride_q, to be regarded as stride of dim S + static_cast(present_strides.strides_for_bnsh_coord.z), // stride_k, to be regarded as stride of dim S + static_cast(present_strides.strides_for_bnsh_coord.z), // stride_v, to be regarded as stride of dim S + batch_size, // stride_bias, if alibi, b*h need set this to h, 1*h need set this to 0 + static_cast(output_strides.strides_for_bnsh_coord.z), // stride_o, to be regarded as stride of dim S + static_cast(query_strides.strides_for_bnsh_coord.y), // nhead_stride_q, to be regarded as stride of dim N + static_cast(present_strides.strides_for_bnsh_coord.y), // nhead_stride_k, to be regarded as stride of dim N + static_cast(present_strides.strides_for_bnsh_coord.y), // nhead_stride_v, to be regarded as stride of dim N + 0, // nhead_stride_bias + batch_size, // nhead_stride_lse + static_cast(output_strides.strides_for_bnsh_coord.y), // batch_stride_o, to be regarded as stride of dim B + static_cast(query_strides.strides_for_bnsh_coord.x), // batch_stride_q, to be regarded as stride of dim B + static_cast(present_strides.strides_for_bnsh_coord.x), // batch_stride_k, to be regarded as stride of dim B + static_cast(present_strides.strides_for_bnsh_coord.x), // batch_stride_v, to be regarded as stride of dim B + 0, // batch_stride_bias + num_heads * batch_size, // batch_stride_lse + static_cast(output_strides.strides_for_bnsh_coord.x), // batch_stride_o, to be regarded as stride of dim B + mask.left, // window_size_left + mask.right, // window_size_right + static_cast(mask.type)}; + +#if 0 + std::cout + << "\n sequence_length:" << sequence_length + << "\n kv_sequence_length:" << kv_sequence_length + << "\n seqlen_past_kv_cache:" << parameters.seqlen_past_kv_cache + << "\n seqlen_present_kv_cache:" << parameters.seqlen_present_kv_cache << std::endl; + + std::cout + << "\n 
q_ptr:" << args.q_ptr + << "\n k_ptr:" << args.k_ptr + << "\n v_ptr:" << args.v_ptr + << "\n bias_ptr:" << args.bias_ptr + << "\n lse_ptr:" << args.lse_ptr + << "\n o_ptr:" << args.o_ptr + << "\n seqstart_q_ptr:" << args.seqstart_q_ptr + << "\n seqstart_k_ptr:" << args.seqstart_k_ptr + << "\n seqlen_k_ptr:" << args.seqlen_k_ptr + << "\n seqlen_q:" << args.seqlen_q + << "\n seqlen_k:" << args.seqlen_k + << "\n batch:" << args.batch + << "\n max_seqlen_q:" << args.max_seqlen_q + << "\n hdim_q:" << args.hdim_q + << "\n hdim_v:" << args.hdim_v + << "\n nhead_q:" << args.nhead_q + << "\n nhead_k:" << args.nhead_k + << "\n scale_s:" << args.scale_s + << "\n scale_p:" << args.scale_p + << "\n scale_o:" << args.scale_o + << "\n stride_q:" << args.stride_q + << "\n stride_k:" << args.stride_k + << "\n stride_v:" << args.stride_v + << "\n stride_bias:" << args.stride_bias + << "\n stride_o:" << args.stride_o + << "\n nhead_stride_q:" << args.nhead_stride_q + << "\n nhead_stride_k:" << args.nhead_stride_k + << "\n nhead_stride_v:" << args.nhead_stride_v + << "\n nhead_stride_bias:" << args.nhead_stride_bias + << "\n nhead_stride_lse:" << args.nhead_stride_lse + << "\n nhead_stride_o:" << args.nhead_stride_o + << "\n batch_stride_q:" << args.batch_stride_q + << "\n batch_stride_k:" << args.batch_stride_k + << "\n batch_stride_v:" << args.batch_stride_v + << "\n batch_stride_bias:" << args.batch_stride_bias + << "\n batch_stride_lse:" << args.batch_stride_lse + << "\n batch_stride_o:" << args.batch_stride_o + << "\n window_size_left:" << args.window_size_left + << "\n window_size_right:" << args.window_size_right + << "\n mask_type:" << args.mask_type + << std::endl; +#endif + + fmha_fwd_traits traits{ + parameters.head_size, + parameters.head_size, // v head size + GetCkFmhaDataTypeString(), + !parameters.is_first_prompt, // true, // is_group_mode + true, // is_v_rowmajor ? dim is fastest : seq is fastest + mask.type, + bias_type, + false, // has_lse + false, // do_fp8_static_quant, aka, squant + }; + + ck_tile::stream_config stream_config{ + hip_stream, + false // time_kernel + }; + + auto duration = fmha_fwd(traits, args, stream_config); + if (duration < 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "fmha_fwd internal error"); + } + HIP_RETURN_IF_ERROR(hipGetLastError()); + + return Status::OK(); +#else + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "GroupQueryAttention requires ck_tile to be enabled"); +#endif +} + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.h b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.h new file mode 100644 index 0000000000000..ce0de1f761aa5 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.h @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include +#include +#include "core/providers/rocm/rocm_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +using namespace onnxruntime::rocm; + +template +class GroupQueryAttention final : public RocmKernel { + public: + GroupQueryAttention(const OpKernelInfo& info); + Status ComputeInternal(OpKernelContext* context) const override; + + protected: + int num_heads_; // number of attention heads + int kv_num_heads_; // different for k and v for group query attention + int local_window_size_; + bool is_unidirectional_; + bool is_past_bsnh_; + bool do_rotary_; + bool rotary_interleaved_; + float scale_; + + private: + static std::once_flag arch_checking_; +}; + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/layer_norm.cuh b/onnxruntime/contrib_ops/rocm/bert/layer_norm.cuh new file mode 100644 index 0000000000000..2eeb7c3e8f279 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/layer_norm.cuh @@ -0,0 +1,270 @@ +#include "hip/hip_runtime.h" +/* + The implementation of this file is based on bert plugins in TensorRT demo: + https://github.com/NVIDIA/TensorRT/tree/release/5.1/demo/BERT/ + +Copyright 2019 NVIDIA Corporation + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+#pragma once + +#include +#include +#include +#include "core/providers/rocm/rocm_common.h" +#include "core/providers/rocm/cu_inc/common.cuh" +#include "core/providers/rocm/shared_inc/rocm_call.h" + +using namespace onnxruntime::rocm; + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +template +__device__ inline T Rsqrt(const T& x); + +template <> +__device__ inline float Rsqrt(const float& x) { + return rsqrtf(x); +} + +template <> +__device__ inline half Rsqrt(const half& x) { +#if __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__) + return hrsqrt(x); +#else + return half(rsqrtf(static_cast(x))); +#endif +} + +__device__ inline half2 AddHalf2(const half2 a, const half2 b) { +#if __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__) + return __hadd2(a, b); +#else + return __halves2half2(__hadd(a.x, b.x), __hadd(a.y, b.y)); +#endif +} + +struct KeyValuePairSum { + __device__ inline hipcub::KeyValuePair operator()(const hipcub::KeyValuePair& a, + const hipcub::KeyValuePair& b) { + return hipcub::KeyValuePair(a.key + b.key, a.value + b.value); + } + + __device__ inline hipcub::KeyValuePair operator()(const hipcub::KeyValuePair& a, + const hipcub::KeyValuePair& b) { + const half2 a2 = __halves2half2(a.key, a.value); + const half2 b2 = __halves2half2(b.key, b.value); + const half2 res = AddHalf2(a2, b2); + return hipcub::KeyValuePair(__low2half(res), __high2half(res)); + } + + __device__ inline hipcub::KeyValuePair operator()(const hipcub::KeyValuePair& a, + const hipcub::KeyValuePair& b) { + return hipcub::KeyValuePair(AddHalf2(a.key, b.key), AddHalf2(a.value, b.value)); + } +}; + +template +__device__ inline void LayerNorm( + const hipcub::KeyValuePair& thread_data, const int ld, const int offset, const V* beta, + const V* gamma, const U epsilon, V* output) { + // Assuming thread_data is already divided by ld + + using BlockReduce = hipcub::BlockReduce, TPB>; + __shared__ typename BlockReduce::TempStorage temp_storage; + __shared__ U mu; // mean + __shared__ U rsigma; // 1 / std.dev. + + KeyValuePairSum pair_sum; + const auto sum_kv = BlockReduce(temp_storage).Reduce(thread_data, pair_sum); + + if (threadIdx.x == 0) { + mu = sum_kv.key; + rsigma = Rsqrt(sum_kv.value - mu * mu + epsilon); + } + __syncthreads(); + + for (int i = threadIdx.x; i < ld; i += TPB) { + const int idx = offset + i; + const U val = static_cast(output[idx]); + const U g = static_cast(gamma[i]); + const U b = (nullptr == beta) ? U(0.f) : static_cast(beta[i]); + output[idx] = static_cast(g * (val - mu) * rsigma + b); + } +} + +template +__device__ inline void SimplifiedLayerNorm( + const U& thread_data, const int ld, const int offset, const V* gamma, const U epsilon, V* output) { + // Assuming thread_data is already divided by ld + + using BlockReduce = hipcub::BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + __shared__ U rsigma; // 1 / std.dev. 
+ + const U sum = BlockReduce(temp_storage).Sum(thread_data); + + if (threadIdx.x == 0) { + rsigma = Rsqrt(sum + epsilon); + } + __syncthreads(); + + for (int i = threadIdx.x; i < ld; i += TPB) { + const int idx = offset + i; + const U val = static_cast(output[idx]); + const U g = static_cast(gamma[i]); + output[idx] = static_cast(g * val * rsigma); + } +} + +template +__device__ inline void SimplifiedLayerNormVec( + const U& thread_data, const int ld, const int offset, const V* gamma, const U epsilon, V* output) { + // Assuming thread_data is already divided by ld + using VecV = aligned_vector; + using BlockReduce = hipcub::BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + __shared__ U rsigma; // 1 / std.dev. + + const U sum = BlockReduce(temp_storage).Sum(thread_data); + + if (threadIdx.x == 0) { + rsigma = Rsqrt(sum + epsilon); + } + __syncthreads(); + + if (ILP * threadIdx.x < ld) { + for (int i = threadIdx.x * ILP; i < ld; i += TPB * ILP) { + int idx = offset + i; + const VecV gamma_v = *reinterpret_cast(gamma + i); + VecV output_v = *reinterpret_cast(output + idx); + +#pragma unroll + for (int k = 0; k < ILP; k++) { + output_v.val[k] = U(gamma_v.val[k]) * U(output_v.val[k]) * rsigma; + } + *(reinterpret_cast(output + idx)) = output_v; + } + } +} + +template +__device__ inline void LayerNormVec( + const hipcub::KeyValuePair& thread_data, const int ld, const int offset, const V* beta, + const V* gamma, const U epsilon, V* output) { + // Assuming thread_data is already divided by ld + using VecV = aligned_vector; + using BlockReduce = hipcub::BlockReduce, TPB>; + __shared__ typename BlockReduce::TempStorage temp_storage; + __shared__ U mu; // mean + __shared__ U rsigma; // 1 / std.dev. + + KeyValuePairSum pair_sum; + const auto sum_kv = BlockReduce(temp_storage).Reduce(thread_data, pair_sum); + + if (threadIdx.x == 0) { + mu = sum_kv.key; + rsigma = Rsqrt(sum_kv.value - mu * mu + epsilon); + } + __syncthreads(); + + if (ILP * threadIdx.x < ld) { + for (int i = threadIdx.x * ILP; i < ld; i += TPB * ILP) { + int idx = offset + i; + const VecV beta_v = (beta != nullptr) ? *reinterpret_cast(beta + i) : VecV(); + const VecV gamma_v = *reinterpret_cast(gamma + i); + VecV output_v = *reinterpret_cast(output + idx); + +#pragma unroll + for (int k = 0; k < ILP; k++) { + output_v.val[k] = (beta != nullptr) ? U(gamma_v.val[k]) * (U(output_v.val[k]) - mu) * rsigma + U(beta_v.val[k]) : U(gamma_v.val[k]) * (U(output_v.val[k]) - mu) * rsigma; + } + *(reinterpret_cast(output + idx)) = output_v; + } + } +} + +template +__device__ inline void LayerNormSmall(const T* input_v, const hipcub::KeyValuePair& thread_data, + const int ld, const int idx, const V* beta, const V* gamma, + const U epsilon, V* output) { + // Assuming thread_data is already divided by ld + // Small settings: the block covers the leading dimension TPB >= ld. The input + // value is available in a register + using VecV = aligned_vector; + using BlockReduce = hipcub::BlockReduce, TPB>; + __shared__ typename BlockReduce::TempStorage temp_storage; + __shared__ U mu; // mean + __shared__ U rsigma; // 1 / std.dev. + + KeyValuePairSum pair_sum; + const hipcub::KeyValuePair sum_kv = BlockReduce(temp_storage).Reduce(thread_data, pair_sum); + + if (threadIdx.x == 0) { + mu = sum_kv.key; + rsigma = Rsqrt(sum_kv.value - mu * mu + epsilon); + } + __syncthreads(); + + if (ILP * threadIdx.x < ld) { + const VecV beta_v = (beta != nullptr) ? 
*reinterpret_cast(beta + threadIdx.x * ILP) : VecV(); + const VecV gamma_v = *reinterpret_cast(gamma + threadIdx.x * ILP); + VecV output_v; + +#pragma unroll + for (int i = 0; i < ILP; i++) { + output_v.val[i] = (beta != nullptr) ? U(gamma_v.val[i]) * (U(input_v[i]) - mu) * rsigma + U(beta_v.val[i]) : U(gamma_v.val[i]) * (U(input_v[i]) - mu) * rsigma; + } + *(reinterpret_cast(output + idx)) = output_v; + } +} + +template +__device__ inline void SimplifiedLayerNormSmall(const T* input_v, const U& thread_data, const int ld, const int idx, + const V* gamma, const U epsilon, V* output) { + // Assuming thread_data is already divided by ld + // Small settings: the block covers the leading dimension TPB >= ld. The input + // value is available in a register + using VecV = aligned_vector; + using BlockReduce = hipcub::BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + __shared__ U rsigma; // 1 / std.dev. + + const U sum = BlockReduce(temp_storage).Sum(thread_data); + + if (threadIdx.x == 0) { + rsigma = Rsqrt(sum + epsilon); + } + __syncthreads(); + + if (ILP * threadIdx.x < ld) { + const VecV gamma_v = *reinterpret_cast(gamma + threadIdx.x * ILP); + VecV output_v; + +#pragma unroll + for (int i = 0; i < ILP; i++) { + output_v.val[i] = U(gamma_v.val[i]) * U(input_v[i]) * rsigma; + } + *(reinterpret_cast(output + idx)) = output_v; + } +} + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu b/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu new file mode 100644 index 0000000000000..5d4ef53b8ba97 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu @@ -0,0 +1,286 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
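Numerically, the block-reduce helpers above accumulate a per-row mean and mean-of-squares (each already divided by ld) and derive mu and rsigma from them. A plain CPU reference of the full and the simplified (RMS-style) form for one row:

#include <cmath>
#include <cstddef>

// Reference for the LayerNorm device helper above: mu = E[x],
// rsigma = 1 / sqrt(E[x^2] - mu^2 + epsilon), out = gamma * (x - mu) * rsigma + beta.
void LayerNormRowReference(const float* x, std::size_t ld, const float* gamma,
                           const float* beta /* may be nullptr */, float epsilon, float* out) {
  float sum = 0.0f, sum_sq = 0.0f;
  for (std::size_t i = 0; i < ld; ++i) {
    sum += x[i];
    sum_sq += x[i] * x[i];
  }
  const float mu = sum / ld;
  const float rsigma = 1.0f / std::sqrt(sum_sq / ld - mu * mu + epsilon);
  for (std::size_t i = 0; i < ld; ++i) {
    const float b = beta ? beta[i] : 0.0f;
    out[i] = gamma[i] * (x[i] - mu) * rsigma + b;
  }
}

// Reference for SimplifiedLayerNorm above: no mean subtraction, no beta.
void SimplifiedLayerNormRowReference(const float* x, std::size_t ld, const float* gamma,
                                     float epsilon, float* out) {
  float sum_sq = 0.0f;
  for (std::size_t i = 0; i < ld; ++i) sum_sq += x[i] * x[i];
  const float rsigma = 1.0f / std::sqrt(sum_sq / ld + epsilon);
  for (std::size_t i = 0; i < ld; ++i) out[i] = gamma[i] * x[i] * rsigma;
}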
+ +#include "contrib_ops/rocm/bert/multihead_attention.h" + +#include "contrib_ops/cpu/bert/multihead_attention_helper.h" +#include "contrib_ops/rocm/bert/attention_impl.h" +#include "contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh" +#include "core/platform/env_var_utils.h" +#include "core/providers/rocm/rocm_common.h" + +using namespace onnxruntime::rocm; +using namespace ::onnxruntime::common; +using namespace ONNX_NAMESPACE; + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +#define REGISTER_MHA_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + MultiHeadAttention, \ + kMSDomain, \ + 1, \ + T, \ + kRocmExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + MultiHeadAttention) + +REGISTER_MHA_KERNEL_TYPED(float); +REGISTER_MHA_KERNEL_TYPED(MLFloat16); + +static constexpr int kPastSequenceLengthInputIndex = 7; +static constexpr int kBeamWidthInputIndex = 8; +static constexpr int kPastInputIndex = 5; +static constexpr int kPresentOutputIndex = 1; + +#define REGISTER_DMMHA_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + DecoderMaskedMultiHeadAttention, \ + kMSDomain, \ + 1, \ + T, \ + kRocmExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .MayInplace(kPastInputIndex, kPresentOutputIndex) \ + .MayInplace(kPastInputIndex + 1, kPresentOutputIndex + 1) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .InputMemoryType(OrtMemTypeCPUInput, kPastSequenceLengthInputIndex) \ + .InputMemoryType(OrtMemTypeCPUInput, kBeamWidthInputIndex), \ + MultiHeadAttention) + +REGISTER_DMMHA_KERNEL_TYPED(float); +REGISTER_DMMHA_KERNEL_TYPED(MLFloat16); + +template +MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info) + : RocmKernel(info), + attn_type_(info.node().OpType() == "DecoderMaskedMultiHeadAttention" ? 
kDecoderMaskedMultiHeadAttention + : kMultiHeadAttention) { + int64_t num_heads = 0; + ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0); + num_heads_ = static_cast(num_heads); + + mask_filter_value_ = info.GetAttrOrDefault("mask_filter_value", -10000.0f); + + scale_ = info.GetAttrOrDefault("scale", 0.0f); + + past_present_share_buffer_ = info.GetAttrOrDefault("past_present_share_buffer", 0LL) != 0LL; + is_unidirectional_ = info.GetAttrOrDefault("unidirectional", 0) == 1; + + using HipT = typename ToHipType::MappedType; + using AttentionTunableOp = GemmSoftmaxGemmPermuteTunableOp; + tunable_op_ = std::make_shared(); +} + +template +Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { + ORT_ENFORCE( + GetTuningContext()->IsTunableOpEnabled(), + "MultiHeadAttention of ROCm EP is only supported if tunable op is used and tuning is enabled."); + + const Tensor* query = context->Input(0); + const Tensor* key = context->Input(1); + const Tensor* value = context->Input(2); + + const Tensor* bias{}; + const Tensor* key_padding_mask{}; + const Tensor* attention_bias{}; + const Tensor* past_key{}; + const Tensor* past_value{}; + const Tensor* past_seq_len{}; + + const Tensor* cache_indirection = nullptr; + + if (attn_type_ == kMultiHeadAttention) { + bias = context->Input(3); + key_padding_mask = context->Input(4); + attention_bias = context->Input(5); + past_key = context->Input(6); + past_value = context->Input(7); + } else if (attn_type_ == kDecoderMaskedMultiHeadAttention) { + key_padding_mask = context->Input(3); + attention_bias = context->Input(4); + past_key = context->Input(5); + past_value = context->Input(6); + past_seq_len = context->Input(kPastSequenceLengthInputIndex); + // const Tensor* beam_width = context->Input(8); // NOTE: not used + // const Tensor* cache_indirection = context->Input(9); // TODO: should not present for ROCm EP + bias = context->Input(10); + } + + if (nullptr != bias) { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "qkv_bias is not supported on ROCm EP. " + "User should fuse the qkv bias to qkv projection instead."); + } + + auto& device_prop = GetDeviceProp(); + RocmAttentionParameters attn; + ORT_RETURN_IF_ERROR(multihead_attention_helper::CheckInputs(query, + key, + value, + bias, + key_padding_mask, + attention_bias, + past_key, + past_value, + cache_indirection, + past_seq_len, + &attn, /* parameters */ + num_heads_, + mask_filter_value_, + scale_, + is_unidirectional_, + past_present_share_buffer_, + attn_type_, + device_prop.maxThreadsPerBlock)); + + if (attn_type_ == kDecoderMaskedMultiHeadAttention && attn.sequence_length != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input sequence length should be 1 to use DecoderMaskedMultiHeadAttention"); + } + + TensorShapeVector output_shape(3); + output_shape[0] = static_cast(attn.batch_size); + output_shape[1] = static_cast(attn.sequence_length); + output_shape[2] = static_cast(attn.v_hidden_size); + Tensor* output = context->Output(0, output_shape); + + std::vector present_dims{ + attn.batch_size, + attn.num_heads, + past_present_share_buffer_ ? 
attn.max_sequence_length : attn.total_sequence_length, + attn.head_size, + }; + TensorShape present_shape(present_dims); + Tensor* present_key = context->Output(1, present_shape); + Tensor* present_value = context->Output(2, present_shape); + + ORT_RETURN_IF_ERROR(ClassifyAttentionMode( + attn_type_, &attn, + /*qkv=*/{query, key, value}, + /*past=*/{past_key, past_value}, + /*present=*/{present_key, present_value})); + + using HipT = typename ToHipType::MappedType; + using AttentionTunableOp = GemmSoftmaxGemmPermuteTunableOp; + auto workspace_bytes = AttentionTunableOp::GetWorkspaceNumBytes(&attn); + auto workspace = GetScratchBuffer(workspace_bytes, context->GetComputeStream()); + + hipStream_t stream = Stream(context); + if (nullptr != present_key) { // process past present concat + Strides dst_strides; + + int4 past_shape; + Strides past_src_strides; + const HipT* past_key_src; + const HipT* past_value_src; + HipT* past_key_dst{}; + HipT* past_value_dst{}; + + int4 add_shape; + Strides add_src_strides; + const HipT* add_key_src = reinterpret_cast(key->DataRaw()); + const HipT* add_value_src = reinterpret_cast(value->DataRaw()); + HipT* add_key_dst; + HipT* add_value_dst; + + if (attn.mode == BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH || + attn.mode == BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH) { + dst_strides = Strides::BNSHMemory(attn.batch_size, attn.num_heads, attn.total_sequence_length, attn.head_size); + + past_shape = {attn.batch_size, attn.num_heads, attn.past_sequence_length, attn.head_size}; + past_src_strides = Strides::BNSHMemory(attn.batch_size, attn.num_heads, attn.past_sequence_length, attn.head_size); + past_key_src = reinterpret_cast(past_key->DataRaw()); + past_value_src = reinterpret_cast(past_value->DataRaw()); + past_key_dst = reinterpret_cast(present_key->MutableDataRaw()); + past_value_dst = reinterpret_cast(present_value->MutableDataRaw()); + + if (attn.mode == BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH) { + add_src_strides = Strides::BSNHMemory(attn.batch_size, attn.kv_sequence_length, attn.num_heads, attn.head_size); + } else if (attn.mode == BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH) { + add_src_strides = Strides::BNSHMemory(attn.batch_size, attn.num_heads, attn.kv_sequence_length, attn.head_size); + } + } else if (attn.mode == BSNH_BLNH_BLNH_NONE_NONE_BNTH_BNTH || + attn.mode == BSNH_BNLH_BNLH_NONE_NONE_BNTH_BNTH) { + dst_strides = Strides::BNSHMemory(attn.batch_size, attn.num_heads, attn.total_sequence_length, attn.head_size); + + if (attn.mode == BSNH_BLNH_BLNH_NONE_NONE_BNTH_BNTH) { + add_src_strides = Strides::BSNHMemory(attn.batch_size, attn.kv_sequence_length, attn.num_heads, attn.head_size); + } else if (attn.mode == BSNH_BNLH_BNLH_NONE_NONE_BNTH_BNTH) { + add_src_strides = Strides::BNSHMemory(attn.batch_size, attn.num_heads, attn.kv_sequence_length, attn.head_size); + } + } else if ( + attn.mode == BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH || + attn.mode == BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH || + attn.mode == BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH || + attn.mode == BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH) { + dst_strides = Strides::BNSHMemory(attn.batch_size, attn.num_heads, attn.max_sequence_length, attn.head_size); + + if (attn.mode == BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH || attn.mode == BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH) { + add_src_strides = Strides::BSNHMemory(attn.batch_size, attn.kv_sequence_length, attn.num_heads, attn.head_size); + } else if (attn.mode == BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH || attn.mode == BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH) { + add_src_strides = 
Strides::BNSHMemory(attn.batch_size, attn.num_heads, attn.kv_sequence_length, attn.head_size); + } + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "past present concatenation is not implemented for attention mode ", attn.mode); + } + add_shape = {attn.batch_size, attn.num_heads, attn.kv_sequence_length, attn.head_size}; // kernel in coord (b,n,s,h) + add_key_dst = reinterpret_cast(present_key->MutableDataRaw()) + dst_strides.OffsetAt(0, 0, attn.past_sequence_length, 0); + add_value_dst = reinterpret_cast(present_value->MutableDataRaw()) + dst_strides.OffsetAt(0, 0, attn.past_sequence_length, 0); + + if (past_key_dst) { + ORT_RETURN_IF_ERROR(LaunchStridedCopy( + stream, past_key_src, past_shape, past_src_strides.ForBNSHCoord(), + past_key_dst, dst_strides.ForBNSHCoord(), device_prop.maxThreadsPerBlock)); + } + if (past_value_dst) { + ORT_RETURN_IF_ERROR(LaunchStridedCopy( + stream, past_value_src, past_shape, past_src_strides.ForBNSHCoord(), + past_value_dst, dst_strides.ForBNSHCoord(), device_prop.maxThreadsPerBlock)); + } + + ORT_RETURN_IF_ERROR(LaunchStridedCopy( + stream, add_key_src, add_shape, add_src_strides.ForBNSHCoord(), + add_key_dst, dst_strides.ForBNSHCoord(), device_prop.maxThreadsPerBlock)); + ORT_RETURN_IF_ERROR(LaunchStridedCopy( + stream, add_value_src, add_shape, add_src_strides.ForBNSHCoord(), + add_value_dst, dst_strides.ForBNSHCoord(), device_prop.maxThreadsPerBlock)); + } + + GemmSoftmaxGemmPermuteParams params; + params.tuning_ctx = GetTuningContext(); + params.stream = context->GetComputeStream(); + params.handle = GetHipblasHandle(context); + params.attention = &attn; + params.device_prop = &device_prop; + params.scale = scale_ == 0 ? 1.0f / sqrt(attn.head_size) : scale_; + std::tie(params.q_buffer, params.k_buffer, params.v_buffer) = ConvertToOffsetedBufferViews( + &attn, + nullptr == query ? nullptr : reinterpret_cast(query->DataRaw()), + nullptr == key ? nullptr : reinterpret_cast(key->DataRaw()), + nullptr == value ? nullptr : reinterpret_cast(value->DataRaw()), + nullptr == present_key ? nullptr : reinterpret_cast(present_key->DataRaw()), + nullptr == present_value ? nullptr : reinterpret_cast(present_value->DataRaw())); + params.out_buffer = reinterpret_cast(output->MutableDataRaw()); + + if (key_padding_mask != nullptr) { + params.mask_index_buffer = key_padding_mask->Data(); + params.mask_index_dims = key_padding_mask->Shape().AsShapeVector(); + } + + if (attention_bias != nullptr) { + params.bias_buffer = reinterpret_cast(attention_bias->DataRaw()); + } + + params.workspace_buffer = reinterpret_cast(workspace.get()); + return (*std::static_pointer_cast(tunable_op_))(¶ms); +} + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/multihead_attention.h b/onnxruntime/contrib_ops/rocm/bert/multihead_attention.h new file mode 100644 index 0000000000000..1d676d7a7bcac --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/multihead_attention.h @@ -0,0 +1,51 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
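The past/present handling above reduces to: when the buffers are not shared, copy the past K/V to the front of present, then append the incoming kv_sequence_length slice at offset past_sequence_length along the S dimension. A contiguous BNSH sketch of that concatenation (illustrative; the real code drives LaunchStridedCopy with per-layout strides):

#include <cstring>

// Contiguous BNSH sketch of the past/present KV concatenation above.
// present holds total_seq = past_seq + new_seq entries per (batch, head).
void ConcatPastAndNewKv(const float* past, const float* added, float* present,
                        int batch, int heads, int past_seq, int new_seq, int head_size) {
  const int total_seq = past_seq + new_seq;
  for (int b = 0; b < batch; ++b) {
    for (int n = 0; n < heads; ++n) {
      float* dst = present + static_cast<std::size_t>((b * heads + n)) * total_seq * head_size;
      if (past) {  // skipped when past and present share the same buffer
        const float* p = past + static_cast<std::size_t>((b * heads + n)) * past_seq * head_size;
        std::memcpy(dst, p, sizeof(float) * past_seq * head_size);
      }
      const float* a = added + static_cast<std::size_t>((b * heads + n)) * new_seq * head_size;
      std::memcpy(dst + static_cast<std::size_t>(past_seq) * head_size, a,
                  sizeof(float) * new_seq * head_size);
    }
  }
}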
+ +#pragma once + +#include +#include "core/providers/rocm/rocm_kernel.h" +#include "contrib_ops/rocm/bert/attention_impl.h" + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +using namespace onnxruntime::rocm; + +template +class MultiHeadAttention final : public RocmKernel { + public: + MultiHeadAttention(const OpKernelInfo& info); + Status ComputeInternal(OpKernelContext* context) const override; + + protected: + AttentionType attn_type_; + int num_heads_; // number of attention heads + float mask_filter_value_; + float scale_; + bool past_present_share_buffer_{false}; + bool is_unidirectional_{false}; + + // type-erased GemmSoftmaxGemmPermuteTunableOp, the reason for this is: + // 1. We don't want to include the cuh file where GemmSoftmaxGemmPermuteTunableOp is defined. + // 2. We don't want to construct the object repeatly (which is expansive) during Compute. + std::shared_ptr tunable_op_; +}; + +template +class DecoderMaskedMultiHeadAttention final : public RocmKernel { + public: + DecoderMaskedMultiHeadAttention(const OpKernelInfo& info); + Status ComputeInternal(OpKernelContext* context) const override; + + protected: + AttentionType mha_type; + int num_heads_; // number of attention heads + float mask_filter_value_; + float scale_; +}; + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.cc b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.cc new file mode 100644 index 0000000000000..9e649fb591896 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.cc @@ -0,0 +1,132 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/rocm/bert/skip_layer_norm.h" + +#include "core/providers/rocm/rocm_common.h" +#include "contrib_ops/rocm/bert/skip_layer_norm_impl.h" +#include "contrib_ops/rocm/bert/transformer_common.h" + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + SkipLayerNormalization, \ + kMSDomain, \ + 1, \ + T, \ + kRocmExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + SkipLayerNorm); \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + SkipSimplifiedLayerNormalization, \ + kMSDomain, \ + 1, \ + T, \ + kRocmExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + SkipLayerNorm); + +REGISTER_KERNEL_TYPED(float) +REGISTER_KERNEL_TYPED(MLFloat16) + +using namespace ONNX_NAMESPACE; + +template +SkipLayerNorm::SkipLayerNorm(const OpKernelInfo& op_kernel_info) : RocmKernel(op_kernel_info) { + ORT_ENFORCE(op_kernel_info.GetAttr("epsilon", &epsilon_).IsOK()); + ORT_ENFORCE(epsilon_ >= 0); +} + +template +Status SkipLayerNorm::ComputeInternal(OpKernelContext* ctx) const { + const Tensor* input = ctx->Input(0); + const Tensor* skip = ctx->Input(1); + const Tensor* gamma = ctx->Input(2); + + const Tensor* beta = Simplified ? nullptr : ctx->Input(3); + const Tensor* bias = Simplified ? 
ctx->Input(3) : ctx->Input(4); + + Tensor* output = ctx->Output(0, input->Shape()); + + // For inferencing, we support one more optional output which is the sum + // of the input and skip tensors + Tensor* skip_input_bias_add_output = ctx->Output(3, input->Shape()); + + if (input->Shape() != skip->Shape()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "skip is expected to have same shape as input"); + } + + if (input->Shape().Size() == 0) { + return Status::OK(); + } + + const auto& input_dims = input->Shape().GetDims(); + size_t input_dims_size = input_dims.size(); + if (input_dims_size != 3 && input_dims_size != 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "input is expected to have 3 or 2 dimensions, got ", input_dims_size); + } + + int hidden_size = static_cast(input_dims[input_dims_size - 1]); + + const auto& gamma_dims = gamma->Shape().GetDims(); + if (gamma_dims.size() != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "gamma is expected to have 1 dimension, got ", gamma_dims.size()); + } + if (gamma_dims[0] != hidden_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Last dimension of gamma and input does not match"); + } + + if (nullptr != beta) { + const auto& beta_dims = beta->Shape().GetDims(); + if (beta_dims.size() != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "beta is expected to have 1 dimension, got ", beta_dims.size()); + } + if (beta_dims[0] != hidden_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Last dimension of beta and input does not match"); + } + } + + if (nullptr != bias) { + const auto& bias_dims = bias->Shape().GetDims(); + if (bias_dims.size() != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "bias is expected to have 1 dimension, got ", bias_dims.size()); + } + if (bias_dims[0] != hidden_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Last dimension of bias and input does not match"); + } + } + + int64_t element_count = input->Shape().Size(); + typedef typename ToHipType::MappedType HipT; + + return LaunchSkipLayerNormKernel( + GetTuningContext(), + ctx->GetComputeStream(), + reinterpret_cast(output->MutableData()), + skip_input_bias_add_output != nullptr ? reinterpret_cast(skip_input_bias_add_output->MutableData()) : nullptr, + reinterpret_cast(input->Data()), + reinterpret_cast(skip->Data()), + reinterpret_cast(gamma->Data()), + (beta != nullptr) ? reinterpret_cast(beta->Data()) : nullptr, + (bias != nullptr) ? reinterpret_cast(bias->Data()) : nullptr, + epsilon_, + hidden_size, + static_cast(element_count)); +} + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.h b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.h new file mode 100644 index 0000000000000..02228bc59cedc --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.h @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
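The ComputeInternal above is mostly shape validation (skip must match input; gamma, beta, and bias must be 1-D of length hidden_size) before handing off to the fused kernel. For intuition, a plain CPU sketch of the per-row math the kernel fuses; the Simplified variant would skip the mean subtraction and beta:

    #include <cmath>
    #include <cstdio>
    #include <vector>

    // One row of SkipLayerNorm: y = LayerNorm(input + skip + bias) * gamma + beta,
    // optionally also emitting the pre-normalization sum.
    void SkipLayerNormRow(const float* x, const float* skip, const float* gamma,
                          const float* beta, const float* bias, float eps,
                          int hidden, float* y, float* sum_out /*optional*/) {
      std::vector<float> s(hidden);
      float mean = 0.f, sq = 0.f;
      for (int i = 0; i < hidden; ++i) {
        s[i] = x[i] + skip[i] + (bias ? bias[i] : 0.f);
        if (sum_out) sum_out[i] = s[i];
        mean += s[i] / hidden;
        sq += s[i] * s[i] / hidden;
      }
      const float var = sq - mean * mean;
      const float rstd = 1.f / std::sqrt(var + eps);
      for (int i = 0; i < hidden; ++i)
        y[i] = (s[i] - mean) * rstd * gamma[i] + (beta ? beta[i] : 0.f);
    }

    int main() {
      const int hidden = 4;
      float x[hidden] = {1, 2, 3, 4}, skip[hidden] = {0.5f, 0.5f, 0.5f, 0.5f};
      float gamma[hidden] = {1, 1, 1, 1}, beta[hidden] = {0, 0, 0, 0}, y[hidden];
      SkipLayerNormRow(x, skip, gamma, beta, nullptr, 1e-5f, hidden, y, nullptr);
      for (float v : y) std::printf("%.4f ", v);
      std::printf("\n");
      return 0;
    }
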
+ +#pragma once +#include "core/common/common.h" +#include "core/providers/rocm/rocm_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +using namespace onnxruntime::rocm; + +template +class SkipLayerNorm final : public RocmKernel { + public: + SkipLayerNorm(const OpKernelInfo& op_kernel_info); + Status ComputeInternal(OpKernelContext* context) const override; + + private: + float epsilon_; +}; + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.cu b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.cu new file mode 100644 index 0000000000000..8387c49a3310b --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.cu @@ -0,0 +1,86 @@ +#include "hip/hip_runtime.h" +/* + The implementation of this file is based on skipLayerNorm plugin in TensorRT demo: + https://github.com/NVIDIA/TensorRT/tree/release/5.1/demo/BERT/ + +Copyright 2019 NVIDIA Corporation + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Modifications: Add SkipLayerNormKernelVec to +// leverage vectorized load/write. +// and templatize ComputeSkipLayerNorm for different +// data types. +// Copyright (c) Advanced Micro Devices, Inc. All rights reserved. +// Licensed under the MIT License. 
+ +#include "contrib_ops/rocm/bert/skip_layer_norm_impl.h" + +#include + +#include "contrib_ops/rocm/bert/skip_layer_norm_impl_kernel.h" +#include "contrib_ops/rocm/bert/skip_layer_norm_tunable_op.h" + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +template +Status LaunchSkipLayerNormKernel( + RocmTuningContext* tuning_ctx, Stream* stream, V* output, T* skip_input_bias_add_output, const T* input, + const T* skip, const V* gamma, const V* beta, const T* bias, float epsilon, int ld, int element_count) { + // this must be true because element_count is the total size of the tensor + assert(element_count % ld == 0); + + SkipLayerNormParams params(tuning_ctx, stream, output, skip_input_bias_add_output, input, skip, + gamma, beta, bias, epsilon, ld, element_count); + + if (tuning_ctx->IsTunableOpEnabled()) { + static SkipLayerNormTunableOp op; + return op(¶ms); + } + + return SkipLayerNormStaticSelection(¶ms); +} + +template Status LaunchSkipLayerNormKernel( + RocmTuningContext* tuning_ctx, Stream* stream, float* output, float* skip_input_bias_add_output, const float* input, + const float* skip, const float* gamma, const float* beta, + const float* bias, float epsilon, int ld, + int element_count); + +template Status LaunchSkipLayerNormKernel( + RocmTuningContext* tuning_ctx, Stream* stream, half* output, half* skip_input_bias_add_output, const half* input, + const half* skip, const half* gamma, const half* beta, + const half* bias, float epsilon, int ld, + int element_count); + +template Status LaunchSkipLayerNormKernel( + RocmTuningContext* tuning_ctx, Stream* stream, float* output, float* skip_input_bias_add_output, const float* input, + const float* skip, const float* gamma, const float* beta, + const float* bias, float epsilon, int ld, + int element_count); + +template Status LaunchSkipLayerNormKernel( + RocmTuningContext* tuning_ctx, Stream* stream, half* output, half* skip_input_bias_add_output, const half* input, + const half* skip, const half* gamma, const half* beta, + const half* bias, float epsilon, int ld, + int element_count); + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.h b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.h new file mode 100644 index 0000000000000..5e2a92447d2f5 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.h @@ -0,0 +1,31 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
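The launcher above either runs the hand-tuned static selection or a TunableOp when tuning is enabled. The following is not the actual ORT TunableOp implementation, only the general shape of the idea: a registry of candidate kernels whose fastest supported member is cached per problem signature and reused afterwards:

    #include <chrono>
    #include <cstdio>
    #include <functional>
    #include <map>
    #include <string>
    #include <vector>

    struct Params {
      int ld, element_count;
      std::string Signature() const { return std::to_string(ld) + "_" + std::to_string(element_count); }
    };
    using Candidate = std::function<bool(const Params&)>;  // returns false if the shape is unsupported

    // Rough sketch: time every supported candidate once, cache the winner per signature.
    struct MiniTunableOp {
      std::vector<Candidate> ops;
      std::map<std::string, size_t> best;
      void operator()(const Params& p) {
        auto it = best.find(p.Signature());
        if (it == best.end()) {
          double best_ms = 1e30;
          size_t best_id = 0;
          for (size_t i = 0; i < ops.size(); ++i) {
            auto t0 = std::chrono::steady_clock::now();
            if (!ops[i](p)) continue;  // candidate rejected this shape
            double ms = std::chrono::duration<double, std::milli>(
                            std::chrono::steady_clock::now() - t0).count();
            if (ms < best_ms) { best_ms = ms; best_id = i; }
          }
          it = best.emplace(p.Signature(), best_id).first;
        }
        ops[it->second](p);  // fast path: reuse the tuned choice
      }
    };

    int main() {
      MiniTunableOp op;
      op.ops.push_back([](const Params& p) { return p.ld % 4 == 0; });  // "vectorized" stand-in
      op.ops.push_back([](const Params&) { return true; });             // always-supported fallback
      op(Params{768, 768 * 128});
      std::printf("tuned candidates: %zu\n", op.ops.size());
      return 0;
    }
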
+ +#pragma once + +#include "core/common/common.h" +#include "core/providers/rocm/tunable/rocm_tunable.h" + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +template +Status LaunchSkipLayerNormKernel( + RocmTuningContext* tuning, + Stream* stream, + V* output, // output tensor + T* skip_input_bias_add_output, // optional output tensor + const T* input, // input tensor + const T* skip, // skip tensor + const V* gamma, // Layer normalization gamma tensor + const V* beta, // Layer normalization beta tensor + const T* bias, // Layer normalization beta tensor + float epsilon, // Layer normalization epsilon + int hidden_size, // hidden size, it is the leading dimension (ld) + int element_count // number of elements in input tensor +); + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl_kernel.h b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl_kernel.h new file mode 100644 index 0000000000000..fcfbc8969e498 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl_kernel.h @@ -0,0 +1,162 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include "contrib_ops/rocm/bert/layer_norm.cuh" + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +template +T maybe2half(float x); + +template <> +float maybe2half(float x) { + return x; +} + +template <> +half maybe2half(float x) { + return __float2half_rn(x); +} + +template +__global__ void SkipLayerNormKernel( + const int ld, const T* input, const T* skip, const V* beta, const V* gamma, const T* bias, + const U epsilon, V* output, T* skip_input_bias_add_output) { + const U reverse_ld = U(1.f / ld); + const int offset = blockIdx.x * ld; + + KeyValuePairSum pair_sum; + // reduce x and x^2 + hipcub::KeyValuePair thread_data(U(0.f), U(0.f)); + + for (int i = threadIdx.x; i < ld; i += TPB) { + const int idx = offset + i; + const U val = (bias == nullptr) ? static_cast(input[idx]) + static_cast(skip[idx]) : static_cast(input[idx]) + static_cast(skip[idx]) + static_cast(bias[i]); + const U rldval = reverse_ld * val; + thread_data = pair_sum(thread_data, hipcub::KeyValuePair(rldval, rldval * val)); + + if (skip_input_bias_add_output != nullptr) { + skip_input_bias_add_output[idx] = static_cast(val); + } + + output[idx] = static_cast(val); + } + + if constexpr (Simplified) { + SimplifiedLayerNorm(thread_data.value, ld, offset, gamma, epsilon, output); + return; + } + + LayerNorm(thread_data, ld, offset, beta, gamma, epsilon, output); +} + +// Vectorized kernel +template +__global__ void SkipLayerNormKernelVec( + const int ld, const T* input, const T* skip, const V* beta, const V* gamma, + const T* bias, const U epsilon, V* output, T* skip_input_bias_add_output, + bool hasBias, bool hasSkipInputBiasAdditionOutput) { + const U reverse_ld = U(1.f / ld); + const int offset = blockIdx.x * ld; + + KeyValuePairSum pair_sum; + // reduce x and x^2 + hipcub::KeyValuePair thread_data(U(0.f), U(0.f)); + + using VecT = aligned_vector; + using VecV = aligned_vector; + if (threadIdx.x * ILP < ld) { + for (int i = threadIdx.x * ILP; i < ld; i += TPB * ILP) { + int idx = offset + i; + + const VecT input_v = *reinterpret_cast(input + idx); + const VecT skip_v = *reinterpret_cast(skip + idx); + const VecT bias_v = hasBias ? 
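SkipLayerNormKernel accumulates the pair (Σ x/ld, Σ x·x/ld) per row through a block-wide KeyValuePair reduction; the first component is the mean and the second is E[x²], from which the variance follows as E[x²] − mean². A quick CPU check of that identity:

    #include <cmath>
    #include <cstdio>

    int main() {
      const int ld = 5;
      const float x[ld] = {1.f, 2.f, 3.f, 4.f, 10.f};

      // The kernel accumulates rldval = x/ld and rldval * x = x*x/ld per element.
      float sum_rld = 0.f, sum_rld_sq = 0.f;
      for (int i = 0; i < ld; ++i) {
        const float rldval = x[i] / ld;
        sum_rld += rldval;            // becomes the mean
        sum_rld_sq += rldval * x[i];  // becomes E[x^2]
      }
      const float mean = sum_rld;
      const float var = sum_rld_sq - mean * mean;  // Var[x] = E[x^2] - E[x]^2

      // Cross-check against the direct two-pass formula.
      float mean2 = 0.f;
      for (int i = 0; i < ld; ++i) mean2 += x[i];
      mean2 /= ld;
      float var2 = 0.f;
      for (int i = 0; i < ld; ++i) var2 += (x[i] - mean2) * (x[i] - mean2);
      var2 /= ld;

      std::printf("mean %.4f vs %.4f, var %.4f vs %.4f\n", mean, mean2, var, var2);
      return 0;
    }
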
*reinterpret_cast(bias + i) : VecT(); + VecT skip_input_bias_add_output_v, output_v; + +#pragma unroll + for (int k = 0; k < ILP; k++) { + const U val = hasBias ? static_cast(input_v.val[k]) + static_cast(skip_v.val[k]) + static_cast(bias_v.val[k]) : static_cast(input_v.val[k]) + static_cast(skip_v.val[k]); + const U rldval = reverse_ld * val; + + if (hasSkipInputBiasAdditionOutput) { + skip_input_bias_add_output_v.val[k] = static_cast(val); + } + thread_data = pair_sum(thread_data, hipcub::KeyValuePair(rldval, rldval * val)); + output_v.val[k] = static_cast(val); + } + + if (hasSkipInputBiasAdditionOutput) { + *(reinterpret_cast(skip_input_bias_add_output + idx)) = skip_input_bias_add_output_v; + } + + *(reinterpret_cast(output + idx)) = output_v; + } + } + + if constexpr (Simplified) { + SimplifiedLayerNormVec(thread_data.value, ld, offset, gamma, epsilon, output); + return; + } + + LayerNormVec(thread_data, ld, offset, beta, gamma, epsilon, output); +} + +// Vectorized kernel +template +__global__ void SkipLayerNormKernelSmall( + const int ld, const T* input, const T* skip, const V* beta, const V* gamma, + const T* bias, const U epsilon, V* output, T* skip_input_bias_add_output, + bool hasBias, bool hasSkipInputBiasAdditionOutput) { + const U rld = U(1.f / ld); + const int idx = blockIdx.x * ld + threadIdx.x * ILP; // grid_size = n / ld + + using VecT = aligned_vector; + hipcub::KeyValuePair thread_data(U(0.f), U(0.f)); + + VecT input_v; + if (ILP * threadIdx.x < ld) { + input_v = *reinterpret_cast(input + idx); + const VecT skip_v = *reinterpret_cast(skip + idx); + const VecT bias_v = hasBias ? *reinterpret_cast(bias + threadIdx.x * ILP) : VecT(); + VecT skip_input_bias_add_output_v; + + U rldval_sum = U(0.f); + U rldvalsq_sum = U(0.f); +#pragma unroll + for (int i = 0; i < ILP; i++) { + const U val = hasBias ? static_cast(input_v.val[i]) + static_cast(skip_v.val[i]) + static_cast(bias_v.val[i]) : static_cast(input_v.val[i]) + static_cast(skip_v.val[i]); + + if (hasSkipInputBiasAdditionOutput) { + skip_input_bias_add_output_v.val[i] = static_cast(val); + } + + const U rldval = rld * val; + rldval_sum += rldval; + rldvalsq_sum += rldval * val; + input_v.val[i] = static_cast(val); + } + + if (hasSkipInputBiasAdditionOutput) { + *(reinterpret_cast(skip_input_bias_add_output + idx)) = skip_input_bias_add_output_v; + } + + thread_data = hipcub::KeyValuePair(rldval_sum, rldvalsq_sum); + } + + if constexpr (Simplified) { + SimplifiedLayerNormSmall(input_v.val, thread_data.value, ld, idx, gamma, epsilon, output); + return; + } + + LayerNormSmall(input_v.val, thread_data, ld, idx, beta, gamma, epsilon, output); +} + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_tunable_op.h b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_tunable_op.h new file mode 100644 index 0000000000000..0391704ce1c56 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_tunable_op.h @@ -0,0 +1,161 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
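The Vec and Small kernels above read aligned_vector packets of ILP elements per thread, which is why their launch constraints require the hidden size to be divisible by the vector size. A scalar C++ sketch of the same unrolled per-packet processing (aligned_vector itself is a ROCm provider helper and is not reproduced here):

    #include <cstdio>

    // Process a row of length ld in packets of ILP elements, mirroring the unrolled
    // inner loop of the vectorized kernels. Precondition: ld % ILP == 0.
    template <int ILP>
    void AddRowVectorized(const float* input, const float* skip, const float* bias,
                          float* out, int ld) {
      for (int i = 0; i < ld; i += ILP) {
        float packet[ILP];                        // stands in for one aligned_vector load
        for (int k = 0; k < ILP; ++k)
          packet[k] = input[i + k] + skip[i + k] + (bias ? bias[i + k] : 0.f);
        for (int k = 0; k < ILP; ++k)             // one aligned store of the whole packet
          out[i + k] = packet[k];
      }
    }

    int main() {
      const int ld = 8;
      float in[ld] = {0, 1, 2, 3, 4, 5, 6, 7}, skip[ld] = {1, 1, 1, 1, 1, 1, 1, 1}, out[ld];
      AddRowVectorized<4>(in, skip, nullptr, out, ld);
      for (float v : out) std::printf("%.0f ", v);
      std::printf("\n");
      return 0;
    }
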
+ +#pragma once + +#include +#include +#include +#include + +#include "contrib_ops/rocm/bert/skip_layer_norm_impl_kernel.h" +#include "core/providers/rocm/cu_inc/common.cuh" +#include "core/providers/rocm/tunable/rocm_tunable.h" + +using onnxruntime::rocm::CeilDiv; + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +template +struct SkipLayerNormParams : OpParams { + SkipLayerNormParams(RocmTuningContext* tuning_ctx, onnxruntime::Stream* stream, V* output, T* skip_input_bias_add_output, const T* input, + const T* skip, const V* gamma, const V* beta, + const T* bias, float epsilon, int ld, int element_count) + : OpParams(tuning_ctx, stream), output(output), skip_input_bias_add_output(skip_input_bias_add_output), input(input), skip(skip), gamma(gamma), beta(beta), bias(bias), epsilon(epsilon), ld(ld), element_count(element_count) {} + + std::string Signature() const override { + std::string sig = std::to_string(ld) + "_" + std::to_string(element_count); + return sig; + } + + V* output; + T* skip_input_bias_add_output; + const T* input; + const T* skip; + const V* gamma; + const V* beta; + const T* bias; + float epsilon; + int ld; + int element_count; +}; + +template +Status SkipLayerNormSmallOp(const SkipLayerNormParams* params) { + // Loosen the hard constraint for ld (hidden_size) to include more possible *Small kernels, + // which could offer better performance in some combinations of ThreadsPerBlock and VecSize. + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + !((params->ld <= 8192 && params->ld % VecSize == 0 && + params->ld <= ThreadsPerBlock * VecSize && params->ld > (ThreadsPerBlock - GPU_WARP_SIZE) * VecSize))); + SkipLayerNormKernelSmall<<element_count, params->ld)), + dim3(ThreadsPerBlock), + 0, params->StreamHandle()>>>( + params->ld, params->input, params->skip, + params->beta, params->gamma, params->bias, static_cast(params->epsilon), params->output, params->skip_input_bias_add_output, + (params->bias == nullptr) ? false : true, (params->skip_input_bias_add_output == nullptr) ? false : true); + return HIP_CALL(hipGetLastError()); +} + +template +Status SkipLayerNormRegularOp(const SkipLayerNormParams* params) { + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + !((params->ld > 0 && params->ld % VecSize == 0 && + (params->ld >= ThreadsPerBlock * VecSize || + (params->ld < GPU_WARP_SIZE && params->ld > (ThreadsPerBlock - GPU_WARP_SIZE) * VecSize))))); + SkipLayerNormKernelVec<<element_count, params->ld)), + dim3(ThreadsPerBlock), + 0, params->StreamHandle()>>>( + params->ld, params->input, params->skip, + params->beta, params->gamma, params->bias, static_cast(params->epsilon), params->output, params->skip_input_bias_add_output, + (params->bias == nullptr) ? false : true, (params->skip_input_bias_add_output == nullptr) ? false : true); + return HIP_CALL(hipGetLastError()); +} + +template +Status SkipLayerNormStaticSelection(const SkipLayerNormParams* params) { + bool hasBias = (params->bias == nullptr) ? false : true; + bool hasSkipInputBiasAdditionOutput = (params->skip_input_bias_add_output == nullptr) ? 
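SkipLayerNormSmallOp above only accepts rows that fit one block's vectorized footprint: ld must be a multiple of VecSize, at most ThreadsPerBlock * VecSize (and at most 8192), and larger than (ThreadsPerBlock − warp size) * VecSize so the block is not oversized. A quick enumeration of which (ThreadsPerBlock, VecSize) pairs accept a given hidden size, assuming the typical AMD wavefront size of 64:

    #include <cstdio>

    int main() {
      const int warp_size = 64;  // typical AMD wavefront size
      const int ld = 768;        // hidden size to test
      const int tpbs[] = {64, 128, 192, 256, 320, 384, 448, 512, 576, 640, 704, 768, 832, 896, 1024};
      const int vecs[] = {1, 2, 4, 8, 16};

      for (int tpb : tpbs)
        for (int v : vecs) {
          const bool ok = ld <= 8192 && ld % v == 0 &&
                          ld <= tpb * v && ld > (tpb - warp_size) * v;
          if (ok) std::printf("Small kernel accepts ld=%d with TPB=%d, VecSize=%d\n", ld, tpb, v);
        }
      return 0;
    }
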
false : true; + const int grid_size = params->element_count / params->ld; + const int block_size = 256; + +#define LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(ELEMENTS, TPB, ILP) \ + if (params->ld <= ELEMENTS) { \ + SkipLayerNormKernelSmall<<StreamHandle()>>>( \ + params->ld, params->input, params->skip, params->beta, params->gamma, params->bias, \ + static_cast(params->epsilon), params->output, params->skip_input_bias_add_output, \ + hasBias, hasSkipInputBiasAdditionOutput); \ + break; \ + } + if (0 == (params->ld % 4)) { + do { + LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(32, 32, 1) + LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(64, 32, 2) + LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(128, 32, 4) + LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(384, 96, 4) + LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(768, 192, 4) + LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(1024, 256, 4) + + SkipLayerNormKernel<<StreamHandle()>>>( + params->ld, params->input, params->skip, params->beta, params->gamma, params->bias, + static_cast(params->epsilon), params->output, params->skip_input_bias_add_output); + } while (0); + } else { + do { + LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(32, 32, 1) + LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(64, 64, 1) + LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(128, 128, 1) + LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(384, 384, 1) + + SkipLayerNormKernel<<StreamHandle()>>>( + params->ld, params->input, params->skip, params->beta, params->gamma, params->bias, + static_cast(params->epsilon), params->output, params->skip_input_bias_add_output); + } while (0); + } + return HIP_CALL(hipPeekAtLastError()); +} // namespace rocm + +#define ADD_OP_FOR_ALL_VEC_SIZE(name, threads_per_block) \ + this->RegisterOp(name); \ + this->RegisterOp(name); \ + this->RegisterOp(name); \ + this->RegisterOp(name); \ + this->RegisterOp(name); + +#define ADD_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(name) \ + ADD_OP_FOR_ALL_VEC_SIZE(name, 64) \ + ADD_OP_FOR_ALL_VEC_SIZE(name, 128) \ + ADD_OP_FOR_ALL_VEC_SIZE(name, 192) \ + ADD_OP_FOR_ALL_VEC_SIZE(name, 256) \ + ADD_OP_FOR_ALL_VEC_SIZE(name, 320) \ + ADD_OP_FOR_ALL_VEC_SIZE(name, 384) \ + ADD_OP_FOR_ALL_VEC_SIZE(name, 448) \ + ADD_OP_FOR_ALL_VEC_SIZE(name, 512) \ + ADD_OP_FOR_ALL_VEC_SIZE(name, 576) \ + ADD_OP_FOR_ALL_VEC_SIZE(name, 640) \ + ADD_OP_FOR_ALL_VEC_SIZE(name, 704) \ + ADD_OP_FOR_ALL_VEC_SIZE(name, 768) \ + ADD_OP_FOR_ALL_VEC_SIZE(name, 832) \ + ADD_OP_FOR_ALL_VEC_SIZE(name, 896) \ + ADD_OP_FOR_ALL_VEC_SIZE(name, 1024) + +template +class SkipLayerNormTunableOp : public TunableOp> { + public: + SkipLayerNormTunableOp() { + this->RegisterOp(SkipLayerNormStaticSelection); + ADD_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(SkipLayerNormSmallOp) + ADD_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(SkipLayerNormRegularOp) + + // NOTE: the 1st kernel is SkipLayerNorm Original implementation. + this->SetDefaultId(0); + } +}; + +#undef ADD_OP_FOR_ALL_VEC_SIZE +#undef ADD_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/transformer_common.cc b/onnxruntime/contrib_ops/rocm/bert/transformer_common.cc new file mode 100644 index 0000000000000..6ae8d1202d462 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/transformer_common.cc @@ -0,0 +1,37 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
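The static selection above wraps the LAUNCH_SKIPLAYERNORM_SMALL_FORWARD ladder in do { ... } while (0) so that the break baked into each macro expansion exits the ladder at the first matching size bucket, with the generic kernel as the fall-through. The same idiom in miniature:

    #include <cstdio>

    // Each launch macro ends in `break;`, so the first matching bucket leaves the
    // surrounding do/while(0) and skips the rest of the ladder.
    #define LAUNCH_IF_AT_MOST(limit, label)     \
      if (n <= (limit)) {                       \
        std::printf("dispatch: %s\n", label);   \
        break;                                  \
      }

    void Dispatch(int n) {
      do {
        LAUNCH_IF_AT_MOST(32, "small-32")
        LAUNCH_IF_AT_MOST(128, "small-128")
        LAUNCH_IF_AT_MOST(1024, "small-1024")
        std::printf("dispatch: generic\n");  // nothing matched: generic kernel
      } while (0);
    }

    int main() {
      Dispatch(100);   // prints "dispatch: small-128"
      Dispatch(4096);  // prints "dispatch: generic"
      return 0;
    }
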
+#include +#include "core/providers/shared_library/provider_api.h" // Include this otherwise Windows build complains Env::Default() missing +#include "core/platform/env_var_utils.h" +#include "contrib_ops/rocm/bert/transformer_common.h" + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +// The environment variable is for testing purpose only, and it might be removed in the future. +// If you need some option in production, please file a feature request. +constexpr const char* kTransformerOptions = "ORT_TRANSFORMER_OPTIONS"; + +// Initialize the singleton instance +TransformerOptions TransformerOptions::instance; + +const TransformerOptions* TransformerOptions::GetInstance() { + if (!instance.initialized_) { + // We do not use critical section here since it is fine to initialize multiple times by different threads. + int value = ParseEnvironmentVariableWithDefault(kTransformerOptions, 0); + instance.Initialize(value); + + if (value > 0) + std::cout << "ORT_TRANSFORMER_OPTIONS: IsPrecisionMode=" << instance.IsPrecisionMode() + << ",DisablePersistentSoftmax=" << instance.DisablePersistentSoftmax() + << ",DisableHalf2=" << instance.DisableHalf2() + << std::endl; + } + + return &instance; +} + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/transformer_common.h b/onnxruntime/contrib_ops/rocm/bert/transformer_common.h new file mode 100644 index 0000000000000..6816b5b9d07ec --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/transformer_common.h @@ -0,0 +1,46 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/rocm/rocm_common.h" + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +class TransformerOptions { + public: + static const TransformerOptions* GetInstance(); + + bool IsPrecisionMode() const { return is_precision_mode_; } + + bool DisablePersistentSoftmax() const { return disable_persistent_softmax_; } + + bool DisableHalf2() const { return disable_half2_; } + + void Initialize(int value) { + is_precision_mode_ = (value & 0x01) > 0; + disable_persistent_softmax_ = (value & 0x02) > 0; + disable_half2_ = (value & 0x04) > 0; + initialized_ = true; + } + + private: + // Default is false. If the mode is on, prefer precision than speed. + bool is_precision_mode_{false}; + + // Disable persistent softmax. + bool disable_persistent_softmax_{false}; + + // Disable half2 kernel. + bool disable_half2_{false}; + + bool initialized_{false}; + + static TransformerOptions instance; +}; + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh new file mode 100644 index 0000000000000..d0a0d09fcbae3 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh @@ -0,0 +1,105 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
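TransformerOptions above packs three test-only switches into one integer environment variable decoded as a bit mask (0x01 precision mode, 0x02 disable persistent softmax, 0x04 disable half2). A standalone sketch of the same decoding, using std::getenv in place of ORT's ParseEnvironmentVariableWithDefault:

    #include <cstdio>
    #include <cstdlib>

    struct Options {
      bool precision_mode = false;
      bool disable_persistent_softmax = false;
      bool disable_half2 = false;
    };

    Options ParseTransformerOptions(const char* env_name) {
      Options opts;
      const char* s = std::getenv(env_name);
      const int value = s ? std::atoi(s) : 0;
      opts.precision_mode = (value & 0x01) != 0;
      opts.disable_persistent_softmax = (value & 0x02) != 0;
      opts.disable_half2 = (value & 0x04) != 0;
      return opts;
    }

    int main() {
      Options o = ParseTransformerOptions("ORT_TRANSFORMER_OPTIONS");
      std::printf("precision=%d persistent_softmax_disabled=%d half2_disabled=%d\n",
                  o.precision_mode, o.disable_persistent_softmax, o.disable_half2);
      return 0;
    }
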
+ +#pragma once + +#include +#include +#include + +#ifdef USE_COMPOSABLE_KERNEL +#include "core/providers/rocm/composable_kernel_common.h" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh" +#endif // USE_COMPOSABLE_KERNEL + +#include "contrib_ops/rocm/diffusion/group_norm_common.h" + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +#ifdef USE_COMPOSABLE_KERNEL + +using onnxruntime::rocm::CKDataTypeAdaptor; + +// The SiLU function is a special case of Swish function, +// The Swish function is parametrized by b, which is set to 1.0 for SiLU. They are defined as: +// SiLU(x) = x * sigmoid(x) +// Swish(x) = x * sigmoid(bx) +// The default value of b is 1.0 in ck::tensor_operation::element_wise::Swish function. We treat them as the same function here. +using Silu = ck::tensor_operation::element_wise::Swish; +using Pass = ck::tensor_operation::element_wise::PassThrough; + +constexpr int Rank = 5; +constexpr int NumReduceDim = 3; + +template +auto GetCKGroupNormNHWCTypeStringAndOps() { + using XDataType = typename CKDataTypeAdaptor::type; + using YDataType = typename CKDataTypeAdaptor::type; + using SaveMeanInvStdDataType = typename CKDataTypeAdaptor::type; + using GammaDataType = float; + using BetaDataType = float; + + using Activation = std::conditional_t; + + std::vector>>> ret; + for (auto&& impl : internal::GetDeviceGroupNormInstances()) { + std::string silu_suffix = WithSilu ? "_Silu" : "_Pass"; + auto type_string = onnxruntime::MakeString(impl->GetTypeString()) + silu_suffix; + auto invoker = impl->MakeInvokerPointer(); + + auto ck_group_norm_op = [impl = std::move(impl), invoker = std::move(invoker)]( + const GroupNormNHWCTunableParams* params) -> Status { + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF((params->skip != nullptr || params->bias != nullptr), + "Input skip or bias is not supported by composable kernel."); + if constexpr (WithSilu) { + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + !params->use_silu, "Silu version only support groupnorm with silu"); + } else { + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + params->use_silu, "Pass version only support groupnorm without silu"); + } + std::vector in_lengths{params->n, params->h, params->w, params->groups, params->channels_per_group}; + std::vector in_out_strides{params->h * params->w * params->c, params->w * params->c, + params->c, params->channels_per_group, 1}; + std::vector gamma_beta_strides{0, 0, 0, params->channels_per_group, 1}; + std::vector reduce_dims{1, 2, 4}; + + auto activation = Activation{}; + + auto arg = impl->MakeArgumentPointer(in_lengths, // lengths + in_out_strides, // xStrides + gamma_beta_strides, // gammaStrides + gamma_beta_strides, // betaStrides + in_out_strides, // yStrides + {0, 0}, // saveMeanStrides + {0, 0}, // saveInvStdStrides + reduce_dims, // reduceDims + params->epsilon, + params->src, + params->gamma, + params->beta, + params->dst, + nullptr, + nullptr, + activation); + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), + impl->GetTypeString(), " does not support the params"); + invoker->Run(arg.get(), StreamConfig{params->StreamHandle()}); + return Status::OK(); + }; + ret.emplace_back(std::make_pair(std::move(type_string), std::move(ck_group_norm_op))); + } + return ret; +} +#endif // USE_COMPOSABLE_KERNEL + +} // namespace rocm +} // namespace contrib +} // namespace 
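The composable-kernel path above views the NHWC activation, whose channel count is groups * channels_per_group, as a 5-D tensor (n, h, w, groups, channels_per_group), reduces over dims {1, 2, 4}, and broadcasts gamma/beta by giving them zero strides on the non-channel dims. A small sketch of those lengths and strides for a hypothetical shape:

    #include <cstdio>
    #include <vector>

    int main() {
      // Example problem: NHWC activation with c = groups * channels_per_group.
      const int n = 2, h = 8, w = 8, groups = 32, channels_per_group = 10;
      const int c = groups * channels_per_group;

      std::vector<long> in_lengths = {n, h, w, groups, channels_per_group};
      // Strides of the 5-D view over the original contiguous NHWC buffer.
      std::vector<long> in_out_strides = {static_cast<long>(h) * w * c, static_cast<long>(w) * c,
                                          c, channels_per_group, 1};
      // gamma/beta only vary along (group, channel-in-group): zero strides broadcast them.
      std::vector<long> gamma_beta_strides = {0, 0, 0, channels_per_group, 1};
      std::vector<int> reduce_dims = {1, 2, 4};  // reduce over h, w, and channel-in-group

      std::printf("view element (n=1, h=3, w=2, g=5, cg=7) maps to flat NHWC index %ld\n",
                  1 * in_out_strides[0] + 3 * in_out_strides[1] + 2 * in_out_strides[2] +
                      5 * in_out_strides[3] + 7 * in_out_strides[4]);
      return 0;
    }
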
onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh new file mode 100644 index 0000000000000..68f7d47282845 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh @@ -0,0 +1,130 @@ +// SPDX-License-Identifier: MIT +// Modifications Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#ifdef USE_COMPOSABLE_KERNEL +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_normalization_fwd.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_normalization_fwd_impl.hpp" +#include "ck/utility/data_type.hpp" + +namespace onnxruntime { +namespace contrib { +namespace rocm { +namespace internal { + +using F16 = ck::half_t; +using F32 = float; + +using Silu = ck::tensor_operation::element_wise::Swish; +using Pass = ck::tensor_operation::element_wise::PassThrough; + +using ck::tensor_operation::device::DeviceNormalizationFwd; // the interface +using ck::tensor_operation::device::DeviceNormalizationFwdImpl; // the implementation + +// See https://github.com/ROCmSoftwarePlatform/composable_kernel/blob/1fefd82ed8/library/src/tensor_operation_instance/gpu/normalization_fwd/normalization_fwd_instance_common.hpp + +template +using device_normalization_f32_instances = std::tuple< + // clang-format off + // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize, SaveMeanInvStdScalarPerVector> + DeviceNormalizationFwdImpl, // irregular size + DeviceNormalizationFwdImpl, // irregular size + DeviceNormalizationFwdImpl, // irregular size + DeviceNormalizationFwdImpl, // irregular size + DeviceNormalizationFwdImpl, // irregular size + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl + // clang-format on + >; + +template +using device_normalization_f16_instances = + // clang-format off + std::tuple < + // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize, SaveMeanInvStdScalarPerVector> + DeviceNormalizationFwdImpl, // irregular size + DeviceNormalizationFwdImpl, // irregular size + DeviceNormalizationFwdImpl, // irregular size + DeviceNormalizationFwdImpl, // irregular size + DeviceNormalizationFwdImpl, // irregular size + DeviceNormalizationFwdImpl, // irregular size + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + 
DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl + // clang-format on + >; + +// Use this function to get implementation +template +std::vector>> +GetDeviceGroupNormInstances() { + return {}; +} + +template <> +std::vector>> +GetDeviceGroupNormInstances< + F16, F32, F32, F16, F32, Silu, 5, 3>(); + +template <> +std::vector>> +GetDeviceGroupNormInstances< + F16, F32, F32, F16, F32, Pass, 5, 3>(); + +template <> +std::vector>> +GetDeviceGroupNormInstances< + F32, F32, F32, F32, F32, Silu, 5, 3>(); + +template <> +std::vector>> +GetDeviceGroupNormInstances< + F32, F32, F32, F32, F32, Pass, 5, 3>(); + +} // namespace internal +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime +#endif // USE_COMPOSABLE_KERNEL diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp16.cu b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp16.cu new file mode 100644 index 0000000000000..ad191314e5e4c --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp16.cu @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifdef USE_COMPOSABLE_KERNEL +#include "contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace onnxruntime { +namespace contrib { +namespace rocm { +namespace internal { + +template <> +std::vector>> +GetDeviceGroupNormInstances() { + std::vector>> instances; + ck::tensor_operation::device::instance::add_device_operation_instances( + instances, + device_normalization_f16_instances{}); + + return instances; +} + +template <> +std::vector>> +GetDeviceGroupNormInstances() { + std::vector>> instances; + ck::tensor_operation::device::instance::add_device_operation_instances( + instances, + device_normalization_f16_instances{}); + + return instances; +} + +} // namespace internal +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime +#endif // USE_COMPOSABLE_KERNEL diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp32.cu b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp32.cu new file mode 100644 index 0000000000000..ceb53ed442abc --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp32.cu @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
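impl.cuh pairs a catch-all primary template that returns no instances with explicit specialization declarations, and impl_fp16.cu / impl_fp32.cu supply the definitions, so the heavyweight composable-kernel instantiations are compiled in separate translation units. The same pattern condensed into one runnable file (in the real code the specialization bodies live in their own .cu files):

    #include <cstdio>
    #include <string>
    #include <vector>

    // "Header": catch-all primary template plus a declared-but-not-defined specialization.
    template <typename T>
    std::vector<std::string> GetInstances() { return {}; }   // no instances by default

    template <>
    std::vector<std::string> GetInstances<float>();           // defined elsewhere

    // "impl_fp32": the explicit specialization's definition.
    template <>
    std::vector<std::string> GetInstances<float>() {
      return {"instance_a", "instance_b"};
    }

    int main() {
      std::printf("float instances: %zu, double instances: %zu\n",
                  GetInstances<float>().size(), GetInstances<double>().size());
      return 0;
    }
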
+ +#ifdef USE_COMPOSABLE_KERNEL +#include "contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace onnxruntime { +namespace contrib { +namespace rocm { +namespace internal { + +template <> +std::vector>> +GetDeviceGroupNormInstances() { + std::vector>> instances; + ck::tensor_operation::device::instance::add_device_operation_instances( + instances, + device_normalization_f32_instances{}); + + return instances; +} + +template <> +std::vector>> +GetDeviceGroupNormInstances() { + std::vector>> instances; + ck::tensor_operation::device::instance::add_device_operation_instances( + instances, + device_normalization_f32_instances{}); + + return instances; +} + +} // namespace internal +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime +#endif // USE_COMPOSABLE_KERNEL diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_common.h b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_common.h new file mode 100644 index 0000000000000..7cff640db2f34 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_common.h @@ -0,0 +1,56 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include + +#include "core/providers/rocm/cu_inc/common.cuh" +#include "core/providers/rocm/rocm_common.h" +#include "core/providers/rocm/tunable/rocm_tunable.h" +#include "contrib_ops/rocm/diffusion/group_norm_common_base.h" + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +template +struct GroupNormNHWCTunableParams : OpParams, GroupNormNHWCParams { + GroupNormNHWCTunableParams(RocmTuningContext* tuning_ctx, + onnxruntime::Stream* ort_stream, + T* output, + T* add_out, + const T* input, + const T* skip, + const T* bias, + const float* gamma, + const float* beta, + float* workspace, + float epsilon, + int batch_size, + int num_channels, + int height, + int width, + int num_groups, + bool use_silu, + bool broadcast_skip, + int channels_per_block) + : OpParams(tuning_ctx, ort_stream), + GroupNormNHWCParams(output, add_out, input, skip, bias, gamma, beta, workspace, epsilon, batch_size, + num_channels, height, width, num_groups, use_silu, broadcast_skip, channels_per_block) {} + + std::string Signature() const override { + std::string silu_suffix = this->use_silu ? "_silu" : "_pass"; + std::string skip_suffix = this->skip != nullptr ? "_skip" : "_noskip"; + std::string broadcast_suffix = this->broadcast_skip ? "_broadcast" : "_nobroadcast"; + std::string bias_suffix = this->bias != nullptr ? "_bias" : "_nobias"; + std::string sig = std::to_string(this->n) + "_" + std::to_string(this->h * this->w) + "_" + + std::to_string(this->c) + "_" + std::to_string(this->groups) + silu_suffix + + skip_suffix + broadcast_suffix + bias_suffix; + return sig; + } +}; + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.cu b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.cu new file mode 100644 index 0000000000000..142aaf14e8d2d --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.cu @@ -0,0 +1,76 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// The ROCM kernel is hipified from CUDA kernel. 
+#include "contrib_ops/rocm/diffusion/group_norm_impl.h" + +#include +#include "contrib_ops/rocm/diffusion/group_norm_common.h" +#include "contrib_ops/rocm/diffusion/group_norm_tunable_op.h" + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +template +Status LaunchGroupNormKernel( + RocmTuningContext* tuning_ctx, + Stream* ort_stream, + T* output, + T* add_out, + const T* input, + const T* skip, + const T* bias, + const float* gamma, + const float* beta, + void* workspace, + float epsilon, + int batch_size, + int num_channels, + int height, + int width, + int num_groups, + bool use_silu, + bool broadcast_skip, + int channels_per_block) { + GroupNormNHWCTunableParams params(tuning_ctx, ort_stream, output, add_out, input, skip, bias, gamma, beta, + reinterpret_cast(workspace), epsilon, batch_size, num_channels, + height, width, num_groups, use_silu, broadcast_skip, channels_per_block); + + if (params.channels_per_block % params.channels_per_group != 0 || + params.channels_per_block > kMaxSize || + (params.channels_per_group % CHANNELS_PER_THREAD != 0)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "GroupNorm in ROCM does not support the input: n=", batch_size, + " h=", height, + " w=", width, + " c=", num_channels, + " groups=", num_groups); + } + + HIP_RETURN_IF_ERROR(hipMemsetAsync( + params.group_sum_buffer, 0, GetGroupNormWorkspaceSizeInBytes(batch_size, num_groups), params.StreamHandle())); + + if (tuning_ctx->IsTunableOpEnabled()) { + static GroupNormNHWCTunableOp op; + return op(¶ms); + } + + return GroupNormNHWCStaticSelection(¶ms); +} + +template Status LaunchGroupNormKernel(RocmTuningContext* tuning_ctx, Stream* stream, half* output, + half* add_out, const half* input, const half* skip, const half* bias, + const float* gamma, const float* beta, void* workspace, float epsilon, + int batch_size, int num_channels, int height, int width, int num_groups, + bool use_silu, bool broadcast_skip, int channels_per_block); + +template Status LaunchGroupNormKernel(RocmTuningContext* tuning_ctx, Stream* stream, float* output, + float* add_out, const float* input, const float* skip, const float* bias, + const float* gamma, const float* beta, void* workspace, float epsilon, + int batch_size, int num_channels, int height, int width, int num_groups, + bool use_silu, bool broadcast_skip, int channels_per_block); + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh new file mode 100644 index 0000000000000..c6ca16bfdfc80 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh @@ -0,0 +1,105 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include + +#include "contrib_ops/rocm/diffusion/group_norm_common.h" +#include "core/providers/rocm/triton_kernel.h" + +using namespace onnxruntime::rocm; + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +#ifdef USE_TRITON_KERNEL + +namespace { + +template +std::string GetGroupNormTritonGroupName() { + std::string ret = "GroupNormTriton_"; + std::string silu_suffix = WithSilu ? 
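LaunchGroupNormKernel above clears a per-(batch, group) accumulation buffer before dispatch. As an assumption for intuition only (GetGroupNormWorkspaceSizeInBytes is defined elsewhere in the tree and may differ), the sum pass plausibly needs two float accumulators per group per image, one for the running sum and one for the running sum of squares:

    #include <cstdio>

    // Assumption: two float accumulators (sum, sum of squares) per (batch, group).
    // The real GetGroupNormWorkspaceSizeInBytes may differ; this is only for intuition.
    size_t GroupNormWorkspaceSizeInBytes(int batch_size, int num_groups) {
      return static_cast<size_t>(batch_size) * num_groups * 2 * sizeof(float);
    }

    int main() {
      std::printf("workspace for n=2, groups=32: %zu bytes\n",
                  GroupNormWorkspaceSizeInBytes(2, 32));
      return 0;
    }
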
"Silu_" : "Pass_"; + ret += silu_suffix; + ret += GetDataTypeName(); + return ret; +} + +} // namespace + +template +auto GetTritonGroupNormNHWCTypeStringAndOps() { + std::vector>>> ret; + auto group_name = GetGroupNormTritonGroupName(); + auto* kernel_list = GetOrtTritonKernelByGroup(group_name); + if (kernel_list == nullptr) { + return ret; + } + + for (auto i : *kernel_list) { + // Check params match + auto* metadata = GetOrtTritonKernelMetadata(i); + auto block_size = metadata->constants.at("BLOCK_SIZE"); + auto hw_size = metadata->constants.at("HW_SIZE"); + auto impl = [i, block_size, hw_size](const GroupNormNHWCTunableParams* params) -> Status { + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + params->channels_per_group > block_size || params->channels_per_group * 2 <= block_size, + "Arg block_size (", block_size, ") is not the next power of 2 of channels_per_group (", + params->channels_per_group, ")."); + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + params->hw % hw_size != 0, "Arg hw_size (", hw_size, ") is not a divisor of hw (", params->hw, ")."); + if constexpr (WithSilu) { + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!params->use_silu, "Silu version does not support GN w/o silu."); + } else { + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(params->use_silu, "Pass version does not support GN w/ silu."); + } + // Construct args for launch kernel + struct { + const void* src; + const void* skip; + const void* bias; + void* out; + void* add_out; + const void* gamma; + const void* beta; + int hw; + int c; + int c_per_group; + float eps; + bool has_skip; + bool has_bias; + bool broadcast_skip; + } args = { + (const void*)params->src, + (const void*)params->skip, + (const void*)params->bias, + (void*)params->dst, + (void*)params->skip_workspace, + (const void*)params->gamma, + (const void*)params->beta, + params->hw, + params->c, + params->channels_per_group, + params->epsilon, + params->skip != nullptr, + params->bias != nullptr, + params->broadcast_skip, + }; + + // Grid dim is (batch_count, groups, 1) + return LaunchTritonKernel(params->StreamHandle(), i, params->n, params->groups, 1, &args, sizeof(args)); + }; + ret.emplace_back(std::make_pair(metadata->name, std::move(impl))); + } + return ret; +} + +#endif // USE_TRITON_KERNEL + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py new file mode 100644 index 0000000000000..5ba96ebc117f0 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py @@ -0,0 +1,135 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# -------------------------------------------------------------------------- + +from itertools import product + +import triton +import triton.language as tl + + +@triton.jit +def group_norm_kernel( + input_ptr, + skip_ptr, + bias_ptr, + output_ptr, + add_out_ptr, + gamma_ptr, + beta_ptr, + img_size, + c, + c_per_group, + eps, + has_skip, + has_bias, + broadcast_skip, + BLOCK_SIZE: tl.constexpr, + HW_SIZE: tl.constexpr, + ACTIVATION_SILU: tl.constexpr, +): + row_x = tl.program_id(0) + row_y = tl.program_id(1) + stride = img_size * c + input_ptr += row_x * stride + row_y * c_per_group + output_ptr += row_x * stride + row_y * c_per_group + gamma_ptr += row_y * c_per_group + beta_ptr += row_y * c_per_group + + cols = tl.arange(0, BLOCK_SIZE) + hw = tl.arange(0, HW_SIZE) + offsets = hw[:, None] * c + cols[None, :] + mask = (cols < c_per_group)[None, :] + + bias = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + if has_skip: + add_out_ptr += row_x * stride + row_y * c_per_group + if broadcast_skip: + broadcast_skip_ptr = skip_ptr + row_x * c + row_y * c_per_group + bias += tl.load(broadcast_skip_ptr + cols, mask=cols < c_per_group, other=0.0).to(tl.float32) + else: + skip_ptr += row_x * stride + row_y * c_per_group + if has_bias: + bias_ptr += row_y * c_per_group + bias += tl.load(bias_ptr + cols, mask=cols < c_per_group, other=0.0).to(tl.float32) + + # Calculate mean and variance + _sum = tl.zeros([HW_SIZE, BLOCK_SIZE], dtype=tl.float32) + _square_sum = tl.zeros([HW_SIZE, BLOCK_SIZE], dtype=tl.float32) + for i in range(tl.cdiv(img_size, HW_SIZE)): + x_ptr = input_ptr + i * HW_SIZE * c + a = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32) + if has_skip and not broadcast_skip: + s_ptr = skip_ptr + i * HW_SIZE * c + s = tl.load(s_ptr + offsets, mask=mask, other=0.0).to(tl.float32) + a += s + if has_bias or broadcast_skip: + a += bias + _sum += a + _square_sum += a * a + if has_skip: + add_y_ptr = add_out_ptr + i * HW_SIZE * c + tl.store(add_y_ptr + offsets, a, mask=mask) + + # Set axis=None (or leave it unspecified) to reduce all axes. + # TODO: In older Triton we have to reduce an axis at a time, but in our case + # for some configs it may have some issue when reducing sequentially along the axes. + group_mean = tl.sum(_sum, axis=None) / (img_size * c_per_group) + group_var = tl.sum(_square_sum, axis=None) / (img_size * c_per_group) - group_mean * group_mean + + rstd = 1 / tl.sqrt(group_var + eps) + + # Normalize and apply linear transformation + gamma = tl.load(gamma_ptr + cols, mask=cols < c_per_group).to(tl.float32) + beta = tl.load(beta_ptr + cols, mask=cols < c_per_group).to(tl.float32) + for i in range(tl.cdiv(img_size, HW_SIZE)): + y_ptr = output_ptr + i * HW_SIZE * c + if has_skip: + add_y_ptr = add_out_ptr + i * HW_SIZE * c + x = tl.load(add_y_ptr + offsets, mask=mask, other=0.0).to(tl.float32) + else: + x_ptr = input_ptr + i * HW_SIZE * c + x = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32) + x_hat = (x - group_mean) * rstd + y = x_hat * gamma + beta + if ACTIVATION_SILU: + y *= tl.sigmoid(y) + tl.store(y_ptr + offsets, y, mask=mask) + + +# We can have more combinations of blocks and hw_sizes, e.g., +# blocks = [16, 32, 64, 128, 256, 512] +# hw_sizes = [8, 16, 32, 64, 128, 256, 512] +# but this will result in too many functions and slow down the compilation. 
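The Triton kernel above computes, per (image, group), the mean and variance over all spatial positions and the group's channels, normalizes, applies gamma/beta, and optionally SiLU (y * sigmoid(y)). A CPU reference of that math for a single (image, group) slice:

    #include <cmath>
    #include <cstdio>
    #include <vector>

    // Reference GroupNorm over one (image, group): x has img_size rows of c_per_group channels.
    void GroupNormSlice(const std::vector<float>& x, const float* gamma, const float* beta,
                        float eps, bool silu, int img_size, int c_per_group,
                        std::vector<float>& y) {
      double sum = 0.0, sq = 0.0;
      for (float v : x) { sum += v; sq += static_cast<double>(v) * v; }
      const double count = static_cast<double>(img_size) * c_per_group;
      const double mean = sum / count;
      const double var = sq / count - mean * mean;
      const double rstd = 1.0 / std::sqrt(var + eps);

      y.resize(x.size());
      for (int i = 0; i < img_size; ++i)
        for (int j = 0; j < c_per_group; ++j) {
          const int idx = i * c_per_group + j;
          float v = static_cast<float>((x[idx] - mean) * rstd) * gamma[j] + beta[j];
          if (silu) v = v / (1.f + std::exp(-v));  // SiLU(v) = v * sigmoid(v)
          y[idx] = v;
        }
    }

    int main() {
      const int img_size = 4, c_per_group = 2;
      std::vector<float> x = {1, 2, 3, 4, 5, 6, 7, 8}, y;
      float gamma[c_per_group] = {1.f, 1.f}, beta[c_per_group] = {0.f, 0.f};
      GroupNormSlice(x, gamma, beta, 1e-5f, /*silu=*/true, img_size, c_per_group, y);
      for (float v : y) std::printf("%.4f ", v);
      std::printf("\n");
      return 0;
    }
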
+with_silu = [True, False] +dtypes = ["fp32", "fp16"] +blocks = [16, 32, 64, 128] +hw_sizes = [8, 16, 32, 64, 128, 256] +warps = [1, 2, 4, 8, 16] +name_pattern = "GroupNormTriton_{}_{}_b{}_hw{}_w{}" +sig_pattern = "*{},*{},*{},*{},*{},*fp32,*fp32,i32,i32,i32,fp32,i1,i1,i1" +group_pattern = "GroupNormTriton_{}_{}" + + +def get_function_table(): + func_table = [] + + for silu, dtype, hw_size, warp, b in product(with_silu, dtypes, hw_sizes, warps, blocks): + silu_suffix = "Silu" if silu else "Pass" + name = name_pattern.format(silu_suffix, dtype, b, hw_size, warp) + group = group_pattern.format(silu_suffix, dtype) + sig = sig_pattern.format(dtype, dtype, dtype, dtype, dtype) + kwargs = { + "num_warps": warp, + "constants": {"BLOCK_SIZE": b, "HW_SIZE": hw_size, "ACTIVATION_SILU": int(silu)}, + } + func_desc = {"name": name, "group": group, "func": group_norm_kernel, "sig": sig, "kwargs": kwargs} + func_table.append(func_desc) + return func_table + + +if __name__ == "__main__": + func_table = get_function_table() + for func_desc in func_table: + print(func_desc) diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_tunable_op.h b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_tunable_op.h new file mode 100644 index 0000000000000..e6831f764b418 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_tunable_op.h @@ -0,0 +1,220 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include "core/providers/rocm/cu_inc/common.cuh" +#include "core/providers/rocm/rocm_common.h" +#include "contrib_ops/rocm/diffusion/group_norm_ck.cuh" +#include "contrib_ops/rocm/diffusion/group_norm_common.h" +#include "contrib_ops/rocm/diffusion/group_norm_impl.h" +#include "contrib_ops/rocm/diffusion/group_norm_impl_kernel.cuh" +#include "contrib_ops/rocm/diffusion/group_norm_triton.cuh" + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +using onnxruntime::rocm::GPU_WARP_SIZE; + +template +void GroupNormNHWCSum(const GroupNormNHWCTunableParams* params) { + dim3 grid; + + // The number of blocks to compute all the channels. + grid.x = DivUp(params->c, params->channels_per_block); + // The number of blocks to compute all the activations in a given instance. + grid.y = DivUp(params->hw, params->hw_per_block); + // The number of instances. + grid.z = params->n; + +#define LAUNCH_GROUPNORM_SUM(ThreadsPerBlock, VecSize) \ + GroupNormNHWCSumKernel \ + <<StreamHandle()>>>( \ + params->skip_workspace, params->group_sum_buffer, params->src, params->skip, params->bias, \ + params->channels_per_block, params->hw_per_block, params->hw, params->hwc, params->c, \ + params->channels_per_group, params->groups, params->groups_per_block, params->broadcast_skip); \ + break; + + // Threads_per_block is half of values in kSizes since CHANNELS_PER_THREAD = 2. 
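GroupNormNHWCSum above launches a 3-D grid: DivUp(c, channels_per_block) blocks along x, DivUp(hw, hw_per_block) along y, and one grid slice per image along z. DivUp is ordinary ceiling division; the shape below is hypothetical:

    #include <cstdio>

    int DivUp(int a, int b) { return (a + b - 1) / b; }

    int main() {
      // Hypothetical GroupNorm shape.
      const int c = 320, channels_per_block = 320, hw = 64 * 64, hw_per_block = 256, n = 2;
      const int grid_x = DivUp(c, channels_per_block);  // blocks covering the channels
      const int grid_y = DivUp(hw, hw_per_block);       // blocks covering the spatial positions
      const int grid_z = n;                             // one slice of blocks per image
      std::printf("grid = (%d, %d, %d)\n", grid_x, grid_y, grid_z);
      return 0;
    }
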
+ switch (params->threads_per_block) { + case 256: + LAUNCH_GROUPNORM_SUM(256, CHANNELS_PER_THREAD) + case 192: + LAUNCH_GROUPNORM_SUM(192, CHANNELS_PER_THREAD) + case 160: + LAUNCH_GROUPNORM_SUM(160, CHANNELS_PER_THREAD) + case 128: + LAUNCH_GROUPNORM_SUM(128, CHANNELS_PER_THREAD) + case 64: + LAUNCH_GROUPNORM_SUM(64, CHANNELS_PER_THREAD) + default: + ORT_NOT_IMPLEMENTED("Not implemented"); + } +} + +template +Status GroupNormNHWCSumOp(const GroupNormNHWCTunableParams* params) { + dim3 grid; + grid.x = DivUp(params->c, params->channels_per_block); + grid.y = DivUp(params->hw, params->hw_per_block); + grid.z = params->n; + + GroupNormNHWCSumKernel + <<StreamHandle()>>>( + params->skip_workspace, params->group_sum_buffer, params->src, params->skip, params->bias, + params->channels_per_block, params->hw_per_block, params->hw, params->hwc, params->c, + params->channels_per_group, params->groups, params->groups_per_block, params->broadcast_skip); + return HIP_CALL(hipGetLastError()); +} + +template +void GroupNormNHWCScale(const GroupNormNHWCTunableParams* params) { + dim3 grid; + + // The number of blocks to compute all the channels. + grid.x = DivUp(params->c, params->channels_per_block); + // The number of blocks to compute all the activations in a given instance. + grid.y = DivUp(params->hw, params->hw_per_block); + // The number of instances. + grid.z = params->n; + +#define LAUNCH_GROUPNORM_SCALE(ThreadsPerBlock, VecSize) \ + GroupNormNHWCScaleKernel \ + <<StreamHandle()>>>( \ + params->dst, params->src, params->skip, params->gamma, params->beta, params->skip_workspace, \ + params->group_sum_buffer, params->epsilon, params->c, params->channels_per_block, \ + params->channels_per_group, params->groups, params->hwc, params->inv_hw_channels_per_group, \ + params->hw, params->hw_per_block, params->use_silu); \ + break; + + // Threads_per_block is half of values in kSizes since CHANNELS_PER_THREAD = 2. 
+ switch (params->threads_per_block) { + case 256: + LAUNCH_GROUPNORM_SCALE(256, CHANNELS_PER_THREAD) + case 192: + LAUNCH_GROUPNORM_SCALE(192, CHANNELS_PER_THREAD) + case 160: + LAUNCH_GROUPNORM_SCALE(160, CHANNELS_PER_THREAD) + case 128: + LAUNCH_GROUPNORM_SCALE(128, CHANNELS_PER_THREAD) + case 64: + LAUNCH_GROUPNORM_SCALE(64, CHANNELS_PER_THREAD) + default: + ORT_NOT_IMPLEMENTED("Not implemented"); + } +} + +template +Status GroupNormNHWCScaleOp(const GroupNormNHWCTunableParams* params) { + dim3 grid; + grid.x = DivUp(params->c, params->channels_per_block); + grid.y = DivUp(params->hw, params->hw_per_block); + grid.z = params->n; + + GroupNormNHWCScaleKernel + <<StreamHandle()>>>( + params->dst, params->src, params->skip, params->gamma, params->beta, params->skip_workspace, + params->group_sum_buffer, params->epsilon, params->c, params->channels_per_block, params->channels_per_group, + params->groups, params->hwc, params->inv_hw_channels_per_group, params->hw, params->hw_per_block, + params->use_silu); + return HIP_CALL(hipGetLastError()); +} + +template +class GroupNormNHWCOp { + public: + Status operator()(const GroupNormNHWCTunableParams* params) { + HIP_RETURN_IF_ERROR(hipMemsetAsync(params->group_sum_buffer, + 0, + GetGroupNormWorkspaceSizeInBytes(params->n, params->groups), + params->StreamHandle())); + auto status = GroupNormNHWCSumOp(params); + ORT_RETURN_IF_ERROR(status); + HIP_RETURN_IF_ERROR(hipGetLastError()); + status = GroupNormNHWCScaleOp(params); + ORT_RETURN_IF_ERROR(status); + HIP_RETURN_IF_ERROR(hipGetLastError()); + return Status::OK(); + } + + Status IsSupported(const GroupNormNHWCTunableParams* params) { + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + !(params->c % VecSize == 0 && params->channels_per_group % VecSize == 0), + "The number of channels (", params->c, ") or the number of channels per group (", params->channels_per_group, + ") isn't divisible by the number of vector size: ", VecSize); + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!(params->channels_per_block <= ThreadsPerBlock * VecSize && + params->channels_per_block > (ThreadsPerBlock - GPU_WARP_SIZE) * VecSize), + "Configuration: Threads (", ThreadsPerBlock, "), vector size (", + VecSize, ") is redundant for the number of channels per group: ", + params->channels_per_block); + + return Status::OK(); + } +}; + +template +Status GroupNormNHWCStaticSelection(const GroupNormNHWCTunableParams* params) { + HIP_RETURN_IF_ERROR(hipMemsetAsync(params->group_sum_buffer, + 0, + GetGroupNormWorkspaceSizeInBytes(params->n, params->groups), + params->StreamHandle())); + GroupNormNHWCSum(params); + HIP_RETURN_IF_ERROR(hipGetLastError()); + GroupNormNHWCScale(params); + HIP_RETURN_IF_ERROR(hipGetLastError()); + return Status::OK(); +} + +#define ADD_OP_FOR_ALL_VEC_SIZE(name, threads_per_block) \ + this->RegisterOp(name{}); \ + this->RegisterOp(name{}); \ + this->RegisterOp(name{}); + +#define ADD_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(name) \ + ADD_OP_FOR_ALL_VEC_SIZE(name, 64) \ + ADD_OP_FOR_ALL_VEC_SIZE(name, 128) \ + ADD_OP_FOR_ALL_VEC_SIZE(name, 192) \ + ADD_OP_FOR_ALL_VEC_SIZE(name, 256) \ + ADD_OP_FOR_ALL_VEC_SIZE(name, 320) + +template +class GroupNormNHWCTunableOp : public TunableOp> { + public: + GroupNormNHWCTunableOp() { + this->RegisterOp(GroupNormNHWCStaticSelection); + ADD_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(GroupNormNHWCOp) + +#ifdef USE_COMPOSABLE_KERNEL + for (auto&& [_, op] : GetCKGroupNormNHWCTypeStringAndOps()) { + ORT_UNUSED_PARAMETER(_); + this->RegisterOp(std::move(op)); + } + + for 
(auto&& [_, op] : GetCKGroupNormNHWCTypeStringAndOps()) { + ORT_UNUSED_PARAMETER(_); + this->RegisterOp(std::move(op)); + } +#endif // USE_COMPOSABLE_KERNEL + +#ifdef USE_TRITON_KERNEL + for (auto&& [_, op] : GetTritonGroupNormNHWCTypeStringAndOps()) { + ORT_UNUSED_PARAMETER(_); + this->RegisterOp(std::move(op)); + } + for (auto&& [_, op] : GetTritonGroupNormNHWCTypeStringAndOps()) { + ORT_UNUSED_PARAMETER(_); + this->RegisterOp(std::move(op)); + } +#endif + } +}; + +#undef ADD_OP_FOR_ALL_VEC_SIZE +#undef ADD_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/diffusion/nhwc_conv.cc b/onnxruntime/contrib_ops/rocm/diffusion/nhwc_conv.cc new file mode 100644 index 0000000000000..35427a02c631d --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/diffusion/nhwc_conv.cc @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/rocm/nn/conv.h" + +using namespace onnxruntime::rocm; + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + NhwcConv, \ + kMSDomain, \ + 1, \ + T, \ + kRocmExecutionProvider, \ + (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + Conv); + +REGISTER_KERNEL_TYPED(float) +REGISTER_KERNEL_TYPED(MLFloat16) + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/fused_conv.cc b/onnxruntime/contrib_ops/rocm/fused_conv.cc new file mode 100644 index 0000000000000..4f3be98d97f80 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/fused_conv.cc @@ -0,0 +1,439 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include "core/common/status.h" +#include "core/providers/rocm/nn/conv.h" +#include "core/providers/rocm/rocm_common.h" + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +namespace { + +// Copied from hipDNN/library/src/hcc_detail/hipdnn_miopen.cpp +miopenStatus_t _miopenAddTensor( + miopenHandle_t handle, + const void* alpha, + const miopenTensorDescriptor_t aDesc, + const void* A, + const void* beta, + const miopenTensorDescriptor_t cDesc, + void* C, + const void* zero_scalar) { + const miopenTensorOp_t tensorOp = miopenTensorOpAdd; + // Using miopenOpTensor to implement Add operator. 
+ // opnd2 = Add ( 0.0 * opnd0, alpha * opnd1 ) + beta * opnd2 + return miopenOpTensor(handle, tensorOp, + zero_scalar, cDesc, C, + alpha, aDesc, A, + beta, cDesc, C); +} + +} // namespace + +template +struct FNVHash { + uint32_t GetValue() const { return value_; } + + void Hash(const void* in_ptr, size_t nbytes) { + auto ptr = reinterpret_cast(in_ptr); + for (size_t i = 0; i < nbytes; ++i) { + value_ ^= ptr[i]; + value_ *= PRIME; + } + } + + template ::value, size_t>::type = 0> + FNVHash& operator<<(const T& pod) { + Hash(&pod, sizeof(pod)); + return *this; + } + + template + FNVHash& operator<<(const std::vector& pod_array) { + for (const auto& pod : pod_array) { + (*this) << pod; + } + return *this; + } + + void HashTensor(miopenTensorDescriptor_t tdesc) { + int size = 0; + miopenGetTensorDescriptorSize(tdesc, &size); + (*this) << size; + std::vector dims(size); + std::vector strides(size); + miopenDataType_t dtype; + miopenGetTensorDescriptor(tdesc, &dtype, dims.data(), strides.data()); + (*this) << dtype; + (*this) << dims; + (*this) << strides; + } + + void HashConvolutionDescriptor(miopenConvolutionDescriptor_t cdesc) { + int spatial_dim = 1; +#if ROCM_VERSION >= 50500 + MIOPEN_CALL(miopenGetConvolutionSpatialDim(cdesc, &spatial_dim)); + std::vector pads{spatial_dim}; + std::vector strides{spatial_dim}; + std::vector dilations{spatial_dim}; + miopenConvolutionMode_t mode; + MIOPEN_CALL(miopenGetConvolutionNdDescriptor(cdesc, spatial_dim, &spatial_dim, pads.data(), strides.data(), dilations.data(), &mode)); +#else + // Previous versions of MIOpen doesn't provide API to probe the dimension of a + // miopenConvolutionDescriptor_t, so we have to guess. + // This algorithm is based on a specific behavior of miopenGetConvolutionNdDescriptor, + // which fails when requestedSpatialDim > the convolution's spatial dimension + constexpr const int kMaxSpatialDim = 5; + std::vector pads{kMaxSpatialDim}; + std::vector strides{kMaxSpatialDim}; + std::vector dilations{kMaxSpatialDim}; + miopenConvolutionMode_t mode; + bool spatial_dim_guessed = false; + for (int i = 0; i < kMaxSpatialDim; i++) { + if (miopenStatusSuccess == miopenGetConvolutionNdDescriptor( + cdesc, i, &spatial_dim, pads.data(), strides.data(), dilations.data(), &mode)) { + spatial_dim_guessed = true; + break; + } + } + ORT_ENFORCE(spatial_dim_guessed, "Failed to guess the actual spatial dimension"); + // Remove the extra dimension + pads.resize(spatial_dim); + strides.resize(spatial_dim); + dilations.resize(spatial_dim); +#endif + (*this) << spatial_dim; + (*this) << pads; + (*this) << strides; + (*this) << dilations; + (*this) << mode; + } + + private: + uint32_t value_ = BASIS; +}; + +template +class FusedConv : public onnxruntime::rocm::Conv { + public: + using Base = onnxruntime::rocm::Conv; + FusedConv(const OpKernelInfo& info) : onnxruntime::rocm::Conv(info) { + std::string activation; + ORT_THROW_IF_ERROR(info.GetAttr("activation", &activation)); + ORT_THROW_IF_ERROR(MapMode(activation)); + MIOPEN_CALL_THROW(miopenCreateActivationDescriptor(&activation_desc_)); + MIOPEN_CALL_THROW(miopenSetActivationDescriptor(activation_desc_, activation_mode_, 0.0, 0.0, 0.0)); + MIOPEN_CALL_THROW(miopenCreateOperatorArgs(&fusion_args_)); + } + + ORT_DISALLOW_COPY_AND_ASSIGNMENT(FusedConv); + + ~FusedConv() { + if (activation_desc_) { + MIOPEN_CALL_THROW(miopenDestroyActivationDescriptor(activation_desc_)); + activation_desc_ = nullptr; + } + + if (fusion_args_) { + miopenDestroyOperatorArgs(fusion_args_); + } + } + + Status 
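FNVHash above folds tensor and convolution descriptors into a 32-bit key for the fusion-plan cache by XORing each byte into the state and multiplying by the FNV prime. A standalone version of that core loop with the standard 32-bit FNV-1a basis and prime (the class template's BASIS/PRIME parameters are not shown in this hunk, so the constants here are the conventional ones):

    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    // FNV-1a over a byte range; 0x811C9DC5 / 0x01000193 are the standard 32-bit parameters.
    uint32_t Fnv1a(const void* data, size_t nbytes, uint32_t value = 0x811C9DC5u) {
      const auto* p = static_cast<const uint8_t*>(data);
      for (size_t i = 0; i < nbytes; ++i) {
        value ^= p[i];
        value *= 0x01000193u;
      }
      return value;
    }

    int main() {
      int dims[4] = {1, 64, 56, 56};
      uint32_t key = Fnv1a(dims, sizeof(dims));
      key = Fnv1a("NCHW", std::strlen("NCHW"), key);  // chain more fields into the same key
      std::printf("cache key: 0x%08X\n", static_cast<unsigned>(key));
      return 0;
    }
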
ComputeInternal(OpKernelContext* context) const override { + std::lock_guard lock(Base::s_.mutex); + + ORT_RETURN_IF_ERROR(Base::UpdateState(context, true)); + if (Base::s_.Y->Shape().Size() == 0) { + return Status::OK(); + } + + bool has_z = nullptr != Base::s_.z_data; + bool has_b = nullptr != Base::s_.b_data; + auto factory = [this](FusedConvFusionData& fusion) { + return this->DoCreateFusionDesc(this->Node().Name(), fusion); + }; + auto& cached_item = plan_cache_.FindOrCreateFusionPlanCache(Hash(), + factory); + bool should_try_fusion_api = cached_item.Validate(this->GetMiopenHandle(context)); + + typedef typename onnxruntime::rocm::ToHipType::MappedType HipT; + const auto alpha = onnxruntime::rocm::Consts::One; + const auto beta = onnxruntime::rocm::Consts::Zero; + IAllocatorUniquePtr workspace = Base::GetWorkSpace(context->GetComputeStream()); + miopenStatus_t fusion_status = miopenStatusNotInitialized; + + if (should_try_fusion_api) { + auto& fusion_info = *cached_item.fusion; + MIOPEN_RETURN_IF_ERROR(miopenSetOpArgsConvForward(fusion_args_, + fusion_info.conv_op, + &alpha, + &beta, + Base::s_.w_data)); + if (has_z) { + MIOPEN_RETURN_IF_ERROR(miopenSetOpArgsBiasForward(fusion_args_, + fusion_info.bias_z_op, + &alpha, + &beta, + Base::s_.z_data)); + } + if (has_b) { + MIOPEN_RETURN_IF_ERROR(miopenSetOpArgsBiasForward(fusion_args_, + fusion_info.bias_b_op, + &alpha, + &beta, + Base::s_.b_data)); + } + if (activation_desc_) { + const float relu_notused = 0.0; + MIOPEN_RETURN_IF_ERROR(miopenSetOpArgsActivForward(fusion_args_, + fusion_info.act_op, + &alpha, + &beta, + relu_notused, + relu_notused, + relu_notused)); + } + fusion_status = miopenExecuteFusionPlan(this->GetMiopenHandle(context), + fusion_info.plan, + Base::s_.x_tensor, + Base::s_.x_data, + Base::s_.y_tensor, + Base::s_.y_data, + fusion_args_); + } + if (miopenStatusSuccess != fusion_status) { + MIOPEN_RETURN_IF_ERROR(miopenConvolutionForward(this->GetMiopenHandle(context), + &alpha, + Base::s_.x_tensor, + Base::s_.x_data, + Base::s_.w_desc, + Base::s_.w_data, + Base::s_.conv_desc, + Base::s_.fwd_algo, + &beta, + Base::s_.y_tensor, + Base::s_.y_data, + workspace.get(), + Base::s_.workspace_bytes)); + if (has_b) { + MIOPEN_RETURN_IF_ERROR(_miopenAddTensor(this->GetMiopenHandle(context), + &alpha, Base::s_.b_tensor, Base::s_.b_data, + &alpha, Base::s_.y_tensor, Base::s_.y_data, + &beta)); + } + if (has_z) { + MIOPEN_RETURN_IF_ERROR(_miopenAddTensor(this->GetMiopenHandle(context), + &alpha, Base::s_.z_tensor, Base::s_.z_data, + &alpha, Base::s_.y_tensor, Base::s_.y_data, + &beta)); + } + MIOPEN_RETURN_IF_ERROR(miopenActivationForward(this->GetMiopenHandle(context), + activation_desc_, + &alpha, + Base::s_.y_tensor, + Base::s_.y_data, + &beta, + Base::s_.y_tensor, + Base::s_.y_data)); + } + if (Base::s_.post_slicing_required) { + ORT_RETURN_IF_ERROR(onnxruntime::rocm::SliceOutUnwantedOutputSection( + this->Stream(context), + Base::s_.y_data, + Base::s_.y_dims_with_adjusted_pads, + Base::s_.Y->MutableDataRaw(), + Base::s_.y_dims.GetDims(), + Base::s_.slice_starts, + Base::s_.slice_ends, + Base::s_.slice_axes, + Base::s_.element_size)); + } + return Status::OK(); + } + + private: + Status MapMode(const std::string& activaton_mode) { + if (activaton_mode == "Relu") { + activation_mode_ = miopenActivationMode_t::miopenActivationRELU; + } else { + return ORT_MAKE_STATUS( + StatusCategory::ONNXRUNTIME, StatusCode::INVALID_ARGUMENT, + "unsupported conv activation mode \"", activaton_mode, "\""); + } + return Status::OK(); + } + 
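The fusion-plan cache below keys its entries with the FNV-style hash defined earlier in this file (the FNVHash template; its numeric template arguments are not visible in this patch). A minimal standalone sketch of that hashing scheme, assuming the standard 32-bit FNV-1a offset basis 2166136261 and prime 16777619:

#include <cstddef>
#include <cstdint>
#include <type_traits>
#include <vector>

struct Fnv1a32 {
  uint32_t value = 2166136261u;  // assumed FNV-1a offset basis

  void HashBytes(const void* data, std::size_t nbytes) {
    const uint8_t* p = static_cast<const uint8_t*>(data);
    for (std::size_t i = 0; i < nbytes; ++i) {
      value ^= p[i];       // xor first, then multiply: the FNV-1a ordering
      value *= 16777619u;  // assumed FNV-1a prime
    }
  }

  template <typename T>
  Fnv1a32& operator<<(const T& pod) {
    static_assert(std::is_trivially_copyable<T>::value, "hash raw bytes of trivially copyable values only");
    HashBytes(&pod, sizeof(pod));
    return *this;
  }
};

// Usage: fold dtype, dims and strides into one 32-bit cache key, mirroring how
// FusedConv::Hash() combines the tensor and convolution descriptors below.
inline uint32_t DescriptorKey(int dtype, const std::vector<int>& dims, const std::vector<int>& strides) {
  Fnv1a32 h;
  h << dtype;
  for (int d : dims) h << d;
  for (int s : strides) h << s;
  return h.value;
}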
miopenActivationMode_t activation_mode_; + miopenActivationDescriptor_t activation_desc_ = nullptr; + + miopenOperatorArgs_t fusion_args_ = nullptr; + + // MIOpen Fusion API + // TODO: create one fusion descriptor shared by multiple FusedConv + // objects + // + // Considerations: + // How do we determine whether two FusedConv objects may share the same fusion + // descriptor? By hashing x_tensor, conv_desc, etc.? + struct FusedConvFusionData { + miopenFusionPlanDescriptor_t plan = nullptr; + miopenFusionOpDescriptor_t conv_op = nullptr; + miopenFusionOpDescriptor_t bias_b_op = nullptr; + miopenFusionOpDescriptor_t bias_z_op = nullptr; + miopenFusionOpDescriptor_t act_op = nullptr; + + // TODO: There is a potential problem. miopenHandle_t may be destroyed and + // re-created later, sharing the same address. Is there currently any way + // to detect it? + mutable std::unordered_set compiled_on; + + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(FusedConvFusionData); + + FusedConvFusionData() {} + ~FusedConvFusionData() { + if (plan) { + miopenDestroyFusionPlan(plan); + } + } + }; + + struct FusionPlanCacheItem { + std::unique_ptr fusion; + Status creation_result; + // TODO: Add a timestamp for eviction + // std::chrono::time_point last_access; + + FusionPlanCacheItem() {} + + miopenStatus_t CompileOnHandle(miopenHandle_t handle) const { + if (!fusion->plan) { + return miopenStatusNotInitialized; + } + auto iter = fusion->compiled_on.find(handle); + if (iter != fusion->compiled_on.end()) { + return miopenStatusSuccess; + } + auto ret = miopenCompileFusionPlan(handle, fusion->plan); + if (miopenStatusSuccess == ret) { + fusion->compiled_on.insert(handle); + } else { + return ret; + } + return miopenStatusSuccess; + } + + bool Validate(miopenHandle_t handle) const { + if (Status::OK() != creation_result) { + return false; + } + if (!fusion || !fusion->plan) { + return false; + } + auto compiling_status = CompileOnHandle(handle); + if (miopenStatusSuccess != compiling_status) { + return false; + } + + return true; + } + }; + + struct FusionPlanCache { + mutable std::mutex mutex; + using HashKey = uint32_t; + std::unordered_map cache_directory_; + + FusionPlanCache() { + } + + FusionPlanCacheItem& FindOrCreateFusionPlanCache(HashKey key, + std::function factory) { + std::lock_guard lock(mutex); + auto iter = cache_directory_.find(key); + if (iter == cache_directory_.end()) { + cache_directory_[key].fusion = std::make_unique(); + cache_directory_[key].creation_result = factory(*cache_directory_[key].fusion); + if (Status::OK() != cache_directory_[key].creation_result) { + cache_directory_[key].fusion.reset(); + } + } + return cache_directory_[key]; + } + }; + + static FusionPlanCache plan_cache_; + + Status DoCreateFusionDesc(const std::string& node_name, FusedConvFusionData& fusion) const { + bool has_z = nullptr != Base::s_.z_data; + bool has_b = nullptr != Base::s_.b_data; + MIOPEN_RETURN_IF_ERROR(miopenCreateFusionPlan(&fusion.plan, + miopenVerticalFusion, + Base::s_.x_tensor)); + auto status = miopenCreateOpConvForward(fusion.plan, &fusion.conv_op, Base::s_.conv_desc, Base::s_.w_desc); + if (status == miopenStatusUnsupportedOp) { + auto msg = MakeString("MIOpen does not support the conv fusion for node \"", + node_name, "\", falling back to the unfused implementation."); + LOGS_DEFAULT(WARNING) << msg; + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, msg); + } + MIOPEN_RETURN_IF_ERROR(status); + + if (has_z) { + MIOPEN_RETURN_IF_ERROR(miopenCreateOpBiasForward(fusion.plan, + &fusion.bias_z_op, + Base::s_.z_tensor));
+ } + if (has_b) { + MIOPEN_RETURN_IF_ERROR(miopenCreateOpBiasForward(fusion.plan, + &fusion.bias_b_op, + Base::s_.b_tensor)); + } + if (activation_desc_) { + MIOPEN_RETURN_IF_ERROR(miopenCreateOpActivationForward(fusion.plan, + &fusion.act_op, + activation_mode_)); + } + return Status::OK(); + } + + uint32_t Hash() const { + FNVHash hash; + bool has_z = nullptr != Base::s_.z_data; + bool has_b = nullptr != Base::s_.b_data; + hash.HashTensor(Base::s_.x_tensor); + hash.HashConvolutionDescriptor(Base::s_.conv_desc); + hash.HashTensor(Base::s_.w_desc); + if (has_z) { + hash.HashTensor(Base::s_.z_tensor); + } + if (has_b) { + hash.HashTensor(Base::s_.b_tensor); + } + if (activation_desc_) { + hash << static_cast(activation_mode_); + } + return hash.GetValue(); + } +}; + +template +typename FusedConv::FusionPlanCache FusedConv::plan_cache_; + +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + FusedConv, \ + kMSDomain, \ + 1, \ + T, \ + kRocmExecutionProvider, \ + (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + FusedConv); + +REGISTER_KERNEL_TYPED(float); +REGISTER_KERNEL_TYPED(MLFloat16); +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/math/gemm_float8.cu b/onnxruntime/contrib_ops/rocm/math/gemm_float8.cu new file mode 100644 index 0000000000000..3539f32252944 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/math/gemm_float8.cu @@ -0,0 +1,213 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/common/common.h" +#include "core/common/float16.h" +#include "core/providers/rocm/rocm_kernel.h" +#include "contrib_ops/rocm/math/gemm_float8_ck.cuh" + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +using namespace onnxruntime::rocm; +using namespace onnxruntime::rocm::tunable::blas; + +class GemmFloat8 final : public RocmKernel { + public: + GemmFloat8(const OpKernelInfo& info) : RocmKernel(info) { + transA_ = info.GetAttrOrDefault("transA", 0); + transB_ = info.GetAttrOrDefault("transB", 0); + dtype_ = info.GetAttrOrDefault("dtype", onnx::TensorProto_DataType_FLOAT16); + alpha_ = info.GetAttrOrDefault("alpha", 1); + beta_ = info.GetAttrOrDefault("beta", 0); + } + Status ComputeInternal(OpKernelContext* ctx) const override; + + private: +#if !defined(DISABLE_FLOAT8_TYPES) + template + Status ComputeFp8Fp16Fp16(OpKernelContext* ctx, int64_t m, int64_t n, int64_t k, + const Tensor* A, const Tensor* scaleA, const Tensor* B, Tensor* C) const; + template + Status ComputeFp16Fp8Fp16(OpKernelContext* ctx, int64_t m, int64_t n, int64_t k, + const Tensor* A, const Tensor* B, const Tensor* scaleB, Tensor* C) const; + + template + [[nodiscard]] inline auto* GetOp() const { + using OpT = GemmFloat8TunableOp; + if (tunable_op_) { + return static_cast(tunable_op_.get()); + } + + auto create = std::make_unique(); // avoid new + tunable_op_ = std::shared_ptr(create.release(), [](void* ptr) { + auto release = std::unique_ptr(); // avoid delete + release.reset(static_cast(ptr)); + }); + + return static_cast(tunable_op_.get()); + } +#endif + + float alpha_; + float beta_; + bool transA_; + bool transB_; + int64_t dtype_; + + // fully type erased + mutable std::shared_ptr tunable_op_; +}; + +Status GemmFloat8::ComputeInternal(OpKernelContext* ctx) const { +#if defined(DISABLE_FLOAT8_TYPES) + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "DISABLE_FLOAT8_TYPES"); +#else + const Tensor* A = ctx->Input(0); + const Tensor* 
B = ctx->Input(1); + const Tensor* C = ctx->Input(2); // bias + const Tensor* scale_a = ctx->Input(3); + const Tensor* scale_b = ctx->Input(4); + const Tensor* scale_y = ctx->Input(5); + + auto a_shape = A->Shape(); + auto b_shape = B->Shape(); + ORT_ENFORCE(a_shape.NumDimensions() == 2); + ORT_ENFORCE(b_shape.NumDimensions() == 2); + + auto m = !transA_ ? a_shape[0] : a_shape[1]; + auto k = !transA_ ? a_shape[1] : a_shape[0]; + ORT_ENFORCE(k == (!transB_ ? b_shape[0] : b_shape[1])); // k is compatible + auto n = !transB_ ? b_shape[1] : b_shape[0]; + + TensorShapeVector output_shape = {m, n}; + Tensor* Y = ctx->Output(0, output_shape); + + ORT_ENFORCE(!transA_, "ROCm GemmFloat8 does not support input A transpose"); + ORT_ENFORCE(dtype_ == onnx::TensorProto_DataType_FLOAT16, "ROCm GemmFloat8 only supports output float16"); + ORT_ENFORCE(C == nullptr, "ROCm GemmFloat8 does not support bias input"); + ORT_ENFORCE(scale_y == nullptr, "ROCm GemmFloat8 does not support output scaling"); + + if (A->IsDataType()) { + return ComputeFp8Fp16Fp16(ctx, m, n, k, A, scale_a, B, Y); + } else if (A->IsDataType()) { + return ComputeFp8Fp16Fp16(ctx, m, n, k, A, scale_a, B, Y); + } else if (B->IsDataType()) { + return ComputeFp16Fp8Fp16(ctx, m, n, k, A, B, scale_b, Y); + } else if (B->IsDataType()) { + return ComputeFp16Fp8Fp16(ctx, m, n, k, A, B, scale_b, Y); + } + + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unhandled type combination of GemmFloat8"); +#endif +} + +#if !defined(DISABLE_FLOAT8_TYPES) +template +Status GemmFloat8::ComputeFp8Fp16Fp16( + OpKernelContext* ctx, int64_t m, int64_t n, int64_t k, + const Tensor* A, const Tensor* scale_a, const Tensor* B, Tensor* C) const { + ORT_ENFORCE(A->IsDataType() && scale_a->IsDataType() && B->IsDataType()); + + onnxruntime::rocm::tunable::blas::GemmFloat8Params params{}; + params.tuning_ctx = GetTuningContext(); + params.stream = ctx->GetComputeStream(); + params.handle = GetHipblasHandle(ctx); + params.opa = transA_ ? tunable::blas::BlasOp::Trans : tunable::blas::BlasOp::NonTrans; + params.opb = transB_ ? tunable::blas::BlasOp::Trans : tunable::blas::BlasOp::NonTrans; + + params.m = m; + params.n = n; + params.k = k; + + params.a = static_cast(A->DataRaw()); + params.lda = transA_ ? m : k; + params.scale_a = alpha_; + params.scale_a_dev = static_cast(scale_a->DataRaw()); + + params.b = static_cast(B->DataRaw()); + params.ldb = transB_ ? 
k : n; + params.scale_b = 1.0f; // NOTE: not used + params.scale_b_dev = nullptr; // NOTE: not used + + params.c = static_cast(C->MutableDataRaw()); + params.ldc = n; + params.scale_c = 1.0f; // NOTE: not implemented + params.scale_c_dev = nullptr; // NOTE: not implemented + + if (!transA_ && !transB_) { + return (*GetOp())(¶ms); + } else if (transA_ && !transB_) { + ORT_NOT_IMPLEMENTED("transA is not implemented"); + } else if (!transA_ && transB_) { + ORT_NOT_IMPLEMENTED("transB is not implemented"); + } else if (transA_ && transB_) { + ORT_NOT_IMPLEMENTED("transA & transB is not implemented"); + } + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unreachable"); +} + +template +Status GemmFloat8::ComputeFp16Fp8Fp16( + OpKernelContext* ctx, int64_t m, int64_t n, int64_t k, + const Tensor* A, const Tensor* B, const Tensor* scale_b, Tensor* C) const { + ORT_ENFORCE(A->IsDataType() && B->IsDataType() && scale_b->IsDataType()); + + onnxruntime::rocm::tunable::blas::GemmFloat8Params params{}; + params.tuning_ctx = GetTuningContext(); + params.stream = ctx->GetComputeStream(); + params.handle = GetHipblasHandle(ctx); + params.opa = transA_ ? tunable::blas::BlasOp::Trans : tunable::blas::BlasOp::NonTrans; + params.opb = transB_ ? tunable::blas::BlasOp::Trans : tunable::blas::BlasOp::NonTrans; + + params.m = m; + params.n = n; + params.k = k; + + params.a = static_cast(A->DataRaw()); + params.lda = transA_ ? m : k; + params.scale_a = 1.0f; // NOTE: not used + params.scale_a_dev = nullptr; // NOTE: not used + + params.b = static_cast(B->DataRaw()); + params.ldb = transB_ ? k : n; + params.scale_b = alpha_; + params.scale_b_dev = static_cast(scale_b->DataRaw()); + + params.c = static_cast(C->MutableDataRaw()); + params.ldc = n; + params.scale_c = 1.0f; // NOTE: not implemented + params.scale_c_dev = nullptr; // NOTE: not implemented + + if (!transA_ && !transB_) { + return (*GetOp())(¶ms); + } else if (transA_ && !transB_) { + ORT_NOT_IMPLEMENTED("transA is not implemented"); + } else if (!transA_ && transB_) { + return (*GetOp())(¶ms); + } else if (transA_ && transB_) { + ORT_NOT_IMPLEMENTED("transA & transB is not implemented"); + } + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unreachable"); +} +#define GEMM_FLOAT8_CONSTRAINTS BuildKernelDefConstraints() +#else +#define GEMM_FLOAT8_CONSTRAINTS BuildKernelDefConstraints() +#endif + +ONNX_OPERATOR_KERNEL_EX( + GemmFloat8, + kMSDomain, + 1, + kRocmExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("TA", GEMM_FLOAT8_CONSTRAINTS) + .TypeConstraint("TB", GEMM_FLOAT8_CONSTRAINTS) + .TypeConstraint("TR", BuildKernelDefConstraints()) + .TypeConstraint("TS", BuildKernelDefConstraints()), + GemmFloat8); + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck.cuh b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck.cuh new file mode 100644 index 0000000000000..b545eb1f2a149 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck.cuh @@ -0,0 +1,276 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
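One idiom worth noting in the GemmFloat8 kernel above: GetOp() lazily constructs the tunable op for the requested type combination and caches it in a fully type-erased std::shared_ptr<void>, with a deleter that restores the concrete type so destruction remains correct. A minimal sketch of that pattern, with illustrative names rather than the kernel's actual types:

#include <memory>

struct ExpensiveOp {  // stands in for a concrete GemmFloat8TunableOp instantiation
  int Run(int x) const { return x * 2; }
};

class Holder {
 public:
  int Call(int x) const { return Get()->Run(x); }

 private:
  ExpensiveOp* Get() const {
    if (!op_) {
      auto created = std::make_unique<ExpensiveOp>();
      // Erase the type so a single member can cache differently instantiated ops;
      // the deleter casts back so the correct destructor still runs.
      op_ = std::shared_ptr<void>(created.release(), [](void* ptr) {
        delete static_cast<ExpensiveOp*>(ptr);
      });
    }
    return static_cast<ExpensiveOp*>(op_.get());
  }

  mutable std::shared_ptr<void> op_;  // fully type-erased cache, as in the kernel
};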
+ +#pragma once + +#include +#include +#include + +#if defined(USE_COMPOSABLE_KERNEL) + +#include "core/providers/rocm/composable_kernel_common.h" + +#include "ck/ck.hpp" +#include "ck/utility/functional3.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_splitk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#endif + +#if !defined(DISABLE_FLOAT8_TYPES) +#include "core/common/float8.h" +#endif +#include "core/providers/rocm/tunable/gemm_common.h" + +namespace onnxruntime { +namespace rocm { +namespace tunable { + +#if defined(USE_COMPOSABLE_KERNEL) && !defined(DISABLE_FLOAT8_TYPES) +using F8 = ck::f8_t; +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +constexpr bool always_false = false; + +template +struct Scale { + constexpr const static bool is_pack2_invocable = true; + constexpr const static bool is_pack4_invocable = true; + + explicit Scale(float scale_value, const float* dev_scale_ptr) : scale_value_{scale_value}, dev_scale_ptr_{dev_scale_ptr} {} + + template + __forceinline__ __host__ __device__ Y fast_type_convert(X x) const { + static_assert(always_false, "not implemented"); + (void)x; + } + + template <> + __forceinline__ __host__ __device__ ck::half_t fast_type_convert(ck::f8_t x) const { + // https://github.com/ROCmSoftwarePlatform/triton/blob/0cc3f8b84a16892396f6e08a04991034d67e32b1/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp#L220-L233 + constexpr const uint16_t mask = 0x7fff; + constexpr const uint16_t sign_mask = 0x8000; + constexpr const uint16_t exp_compensate = []() { + if constexpr (std::is_same_v) { + return 0x2000; + } else if constexpr (std::is_same_v) { + return 0x1c00; + } + }(); + + uint8_t x_u8 = reinterpret_cast(x); + uint16_t x_u16 = static_cast(x_u8) << 8; + uint16_t exp = (x_u16 & mask) >> 1; + uint16_t y = (x_u16 & sign_mask) | (exp + exp_compensate); + return reinterpret_cast(y); + } + + __forceinline__ __host__ __device__ void operator()(ck::half_t& y, const ck::f8_t& x) const { + float scale = scale_value_ * (*dev_scale_ptr_); + y = ck::type_convert(scale * fast_type_convert(x)); + } + + __forceinline__ __host__ __device__ void operator()(ck::half2_t& ys, const ck::f8x2_t& xs) const { + float scale = scale_value_ * (*dev_scale_ptr_); + constexpr const uint32_t mask = 0x7fff7fff; + constexpr const uint32_t sign_mask = 0x80008000; + constexpr const uint32_t exp_compensate = []() { + if constexpr (std::is_same_v) { + return 0x20002000; + } else if constexpr (std::is_same_v) { + return 0x1c001c00; + } + }(); + + const uchar2& x2_u8 = reinterpret_cast(xs); + uchar4 x{0, x2_u8.x, 0, x2_u8.y}; + uint32_t x_u32 = reinterpret_cast(x); + + uint32_t exp = (x_u32 & mask) >> 1; + uint32_t v = (x_u32 & sign_mask) | (exp + exp_compensate); + ys = scale * reinterpret_cast(v); + } + + __forceinline__ __host__ __device__ void operator()(ck::half4_t& ys, const ck::f8x4_t& xs) const { + float scale = scale_value_ * (*dev_scale_ptr_); + constexpr const uint32_t mask = 0x7fff7fff; + constexpr const uint32_t sign_mask = 0x80008000; + constexpr const uint32_t exp_compensate = []() { + if constexpr (std::is_same_v) { + return 0x20002000; + } else if constexpr (std::is_same_v) { + return 0x1c001c00; + } + }(); + + uint32_t xs_u32 = reinterpret_cast(xs); + uint32_t x_u32_0 = __byte_perm(xs_u32, 0, 0x1504); + uint32_t x_u32_1 = __byte_perm(xs_u32, 0, 0x3726); + 
uint32_t exp_0 = (x_u32_0 & mask) >> 1; + uint32_t exp_1 = (x_u32_1 & mask) >> 1; + uint32_t v_0 = (x_u32_0 & sign_mask) | (exp_0 + exp_compensate); + uint32_t v_1 = (x_u32_1 & sign_mask) | (exp_1 + exp_compensate); + uint64_t v = v_0 | uint64_t(v_1) << 32; + ys = scale * reinterpret_cast(v); + } + + float scale_value_; + const float* const dev_scale_ptr_; +}; +#endif + +namespace blas { + +template +struct GemmFloat8Params : tunable::OpParams { + std::string Signature() const override { + return MakeString(BlasOpToString(opa), BlasOpToString(opb), "_", m, "_", n, "_", k); + } + + hipblasHandle_t handle; + BlasOp opa; + BlasOp opb; + int64_t m; + int64_t n; + int64_t k; + float scale_a{}; + const float* scale_a_dev{}; + const TA* a; + int64_t lda; + float scale_b{}; + const float* scale_b_dev{}; + const TB* b; + int64_t ldb; + TC* c; + float scale_c{}; + const float* scale_c_dev{}; + int64_t ldc; +}; + +#if defined(USE_COMPOSABLE_KERNEL) && !defined(DISABLE_FLOAT8_TYPES) + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using Nop = ck::tensor_operation::element_wise::PassThrough; + +void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances( + std::vector, Nop, Nop>>>& instances); + +void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances( + std::vector, Nop, Nop>>>& instances); + +void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances( + std::vector, Nop>>>& instances); + +void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances( + std::vector, Nop>>>& instances); + +void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances( + std::vector, Nop>>>& instances); + +void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances( + std::vector, Nop>>>& instances); + +template +auto CreateOp(float scale, const float* dev_scale) { + if constexpr (std::is_same_v) { + return Scale(scale, dev_scale); + } else if constexpr (std::is_same_v) { + return Scale(scale, dev_scale); + } else { + return Nop{}; + } +} + +template +auto GetCKF8SplitKGemmTypeStringAndOps() { + using CKTA = typename CKDataTypeAdaptor::type; + using CKTB = typename CKDataTypeAdaptor::type; + using CKTC = typename CKDataTypeAdaptor::type; + + using CKLayoutA = typename CKBlasOpAdaptor::type; + using CKLayoutB = typename CKBlasOpAdaptor::type; + + using OpA = std::conditional_t, Scale, Nop>; + using OpB = std::conditional_t, Scale, Nop>; + using OpC = std::conditional_t, Scale, Nop>; + + using DeviceGemm = ck::tensor_operation::device::DeviceGemmSplitK< + CKLayoutA, CKLayoutB, Row, + CKTA, CKTB, CKTC, + OpA, OpB, OpC>; + + std::vector>>> ret; + + for (auto num_split : {1, 4, 16, 64}) { + std::vector> instances{}; + if constexpr (std::is_same_v && std::is_same_v && std::is_same_v && + std::is_same_v && std::is_same_v) { + add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances(instances); + } else if constexpr (std::is_same_v && std::is_same_v && std::is_same_v && + std::is_same_v && std::is_same_v) { + add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances(instances); + } else if constexpr (std::is_same_v && std::is_same_v && std::is_same_v && + std::is_same_v && std::is_same_v) { + add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances(instances); + } else { + static_assert(always_false, "no instances for the type combination"); + LOGS_DEFAULT(FATAL) << "no instances for the type combination"; + } + for (auto&& impl : instances) { + auto type_string = std::to_string(ret.size()) + "_" + impl->GetTypeString() + "_SplitK" + std::to_string(num_split); + 
auto invoker = impl->MakeInvokerPointer(); + auto ck_gemm_op = [num_split, impl = std::move(impl), invoker = std::move(invoker)](const GemmFloat8Params* params) -> Status { + OpA op_a = CreateOp(params->scale_a, params->scale_a_dev); + OpB op_b = CreateOp(params->scale_b, params->scale_b_dev); + OpC op_c = CreateOp(params->scale_c, params->scale_c_dev); + + auto arg = impl->MakeArgumentPointer(params->a, params->b, params->c, + params->m, params->n, params->k, + params->lda, params->ldb, params->ldc, + op_a, op_b, op_c, num_split); + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), + impl->GetTypeString(), " does not support ", params->Signature()); + invoker->Run(arg.get(), StreamConfig{params->StreamHandle()}); + return Status::OK(); + }; + ret.emplace_back(std::make_pair(std::move(type_string), std::move(ck_gemm_op))); + } + } + return ret; +} + +#endif // USE_COMPOSABLE_KERNEL + +template +class GemmFloat8TunableOp : public TunableOp> { + public: + GemmFloat8TunableOp() { +#if defined(USE_COMPOSABLE_KERNEL) && !defined(DISABLE_FLOAT8_TYPES) + for (auto&& [_, op] : GetCKF8SplitKGemmTypeStringAndOps()) { + ORT_UNUSED_PARAMETER(_); + this->RegisterOp(std::move(op)); + } +#else + ORT_ENFORCE(false, "CK is required to support GemmFloat8 computing"); +#endif // USE_COMPOSABLE_KERNEL + } +}; + +} // namespace blas +} // namespace tunable +} // namespace rocm +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/add_instance.cu b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/add_instance.cu new file mode 100644 index 0000000000000..4c691dd18f2e9 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/add_instance.cu @@ -0,0 +1,124 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
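The Scale functor in gemm_float8_ck.cuh above converts fp8 bytes to fp16 with a bit-reinterpretation trick (shift the byte into the high half, halve the exponent/mantissa field, add an exponent-bias compensation) instead of a table lookup. A scalar host-side sketch of the same arithmetic, assuming the 0x2000 branch corresponds to a bias-7 e4m3 encoding; the stripped template arguments in this patch hide which fp8 type each constant belongs to, and denormals/NaN are not handled exactly, matching the intent of the fast path:

#include <cstdint>

// e4m3 byte: 1 sign bit, 4 exponent bits (assumed bias 7), 3 mantissa bits.
// fp16:      1 sign bit, 5 exponent bits (bias 15), 10 mantissa bits.
inline uint16_t E4m3BitsToFp16Bits(uint8_t x) {
  const uint16_t x16 = static_cast<uint16_t>(x) << 8;                    // place the fp8 bits in the high byte
  const uint16_t exp_mant = static_cast<uint16_t>((x16 & 0x7fff) >> 1);  // align exponent/mantissa with the fp16 fields
  return static_cast<uint16_t>((x16 & 0x8000)                            // keep the sign bit
                               | (exp_mant + 0x2000));                   // add (15 - 7) << 10 to re-bias the exponent
}

The packed half2/half4 overloads in the kernel apply the same compensation to two or four lanes at once and then multiply by scale_value_ * (*dev_scale_ptr_).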
+ +#include + +#if defined(USE_COMPOSABLE_KERNEL) && !defined(DISABLE_FLOAT8_TYPES) + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp" + +#include "contrib_ops/rocm/math/gemm_float8_ck.cuh" + +namespace onnxruntime { +namespace rocm { +namespace tunable { +namespace blas { + +using F8 = ck::f8_t; +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +namespace internal { +void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck( + std::vector, PassThrough>>>& instances); + +void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck( + std::vector, PassThrough>>>& instances); + +void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort( + std::vector, PassThrough>>>& instances); + +void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort( + std::vector, PassThrough>>>& instances); +} // namespace internal + +void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances( + std::vector, PassThrough>>>& instances) { + internal::add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck(instances); + internal::add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort(instances); +} + +void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances( + std::vector, PassThrough>>>& instances) { + internal::add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck(instances); + internal::add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort(instances); +} + +namespace internal { +void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck( + std::vector, PassThrough, PassThrough>>>& instances); + +void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck( + std::vector, PassThrough, PassThrough>>>& instances); + +// TODO: The first attempt at derivation did not go well due to various constraints. +// void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ort( +// std::vector, PassThrough, PassThrough>>>& instances); + +// TODO: The first attempt at derivation did not go well due to various constraints.
+// void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ort( +// std::vector, PassThrough, PassThrough>>>& instances); +} // namespace internal + +void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances( + std::vector, PassThrough, PassThrough>>>& instances) { + internal::add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck(instances); + // internal::add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ort(instances); // TODO: +} + +void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances( + std::vector, PassThrough, PassThrough>>>& instances) { + internal::add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck(instances); + // internal::add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ort(instances); // TODO: +} + +namespace internal { +void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_ck( + std::vector, PassThrough>>>& instances); + +void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_ck( + std::vector, PassThrough>>>& instances); +} // namespace internal + +void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances( + std::vector, PassThrough>>>& instances) { + internal::add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_ck(instances); +} + +void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances( + std::vector, PassThrough>>>& instances) { + internal::add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_ck(instances); +} + +} // namespace blas +} // namespace tunable +} // namespace rocm +} // namespace onnxruntime + +#endif diff --git a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance.cu b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance.cu new file mode 100644 index 0000000000000..49463e58886f8 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance.cu @@ -0,0 +1,97 @@ +// SPDX-License-Identifier: MIT +// Modifications Copyright (c) Microsoft. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
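The instance files that follow are pure tuning tables: each DeviceGemmXdlSplitKCShuffle specialization is one split-K GEMM configuration, and add_instance.cu above concatenates the stock CK list with the ORT-derived list (where one exists) so the tunable op can benchmark every candidate. For orientation, split-K partitions the K dimension into num_split partial GEMMs whose results are accumulated into C; a host-side sketch of that arithmetic, illustrative only and not the CK kernel structure:

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

void SplitKGemmReference(int64_t m, int64_t n, int64_t k, int num_split,
                         const std::vector<float>& a,   // m x k, row-major
                         const std::vector<float>& b,   // k x n, row-major
                         std::vector<float>& c) {       // m x n, row-major
  c.assign(static_cast<std::size_t>(m * n), 0.0f);
  const int64_t k_per_split = (k + num_split - 1) / num_split;
  for (int s = 0; s < num_split; ++s) {  // on the GPU each split is an independent batch of workgroups
    const int64_t k_begin = static_cast<int64_t>(s) * k_per_split;
    const int64_t k_end = std::min<int64_t>(k, k_begin + k_per_split);
    for (int64_t i = 0; i < m; ++i) {
      for (int64_t j = 0; j < n; ++j) {
        float acc = 0.0f;
        for (int64_t p = k_begin; p < k_end; ++p) {
          acc += a[i * k + p] * b[p * n + j];
        }
        c[i * n + j] += acc;  // accumulate partial results across splits
      }
    }
  }
}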
+ +#include + +#if defined(USE_COMPOSABLE_KERNEL) && !defined(DISABLE_FLOAT8_TYPES) + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +#include "contrib_ops/rocm/math/gemm_float8_ck.cuh" + +namespace onnxruntime { +namespace rocm { +namespace tunable { +namespace blas { +namespace internal { + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle; + +template +using device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_generic = std::tuple< + // clang-format off + //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNKPadding, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 1, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 8>, 2>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNKPadding, 64, 32, 32, 4, 8, 32, 32, 1, 1, S<1, 2, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 1, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 2> + // clang-format on + >; + +// The derived version is simply double BBlockTransferSrcScalarPerVector and adjust other values correspondingly +template +using device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort = std::tuple< + // clang-format off + //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| 
CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| + //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Type| + //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| | + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 256, 128, 8, 4, 32, 32, 4, 2, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 128, 256, 8, 4, 32, 32, 2, 4, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 128, 128, 8, 4, 32, 32, 4, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 64, 192, 8, 4, 32, 32, 1, 3, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 192, 64, 8, 4, 32, 32, 3, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 128, 128, 8, 4, 32, 32, 2, 2, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 64, 128, 8, 4, 32, 32, 2, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 128, 64, 8, 4, 32, 32, 2, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 32, 1>, S<0, 
1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 1, 1, S<1, 16, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 64, 128, 8, 4, 32, 32, 1, 2, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 32, 192, 8, 4, 32, 32, 1, 3, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 12, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 16, 4, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 192, 32, 8, 4, 32, 32, 3, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 32, 64, 8, 4, 32, 32, 1, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 64, 32, 8, 4, 32, 32, 1, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 32, 128, 8, 4, 32, 32, 1, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 128, 32, 8, 4, 32, 32, 2, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 1, 1, S<1, 32, 1, 4>, 8, F16> + // clang-format on + >; + +void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort( + std::vector, PassThrough>>>& instances) { + ck::tensor_operation::device::instance::add_device_operation_instances( + instances, device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort{}); + ck::tensor_operation::device::instance::add_device_operation_instances( + instances, device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_generic{}); +} + +void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort( + std::vector, PassThrough>>>& instances) { + ck::tensor_operation::device::instance::add_device_operation_instances( + instances, device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort{}); + ck::tensor_operation::device::instance::add_device_operation_instances( + instances, device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_generic{}); +} + +} // namespace internal +} // namespace blas +} // namespace tunable +} // namespace rocm +} // namespace onnxruntime + +#endif diff --git a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance_original.cu b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance_original.cu new file mode 100644 index 0000000000000..236e5555051fc --- /dev/null +++ 
b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance_original.cu @@ -0,0 +1,80 @@ +// SPDX-License-Identifier: MIT +// Modifications Copyright (c) Microsoft. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#if defined(USE_COMPOSABLE_KERNEL) && !defined(DISABLE_FLOAT8_TYPES) + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +#include "contrib_ops/rocm/math/gemm_float8_ck.cuh" + +namespace onnxruntime { +namespace rocm { +namespace tunable { +namespace blas { +namespace internal { + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +using ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +template +using device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck = std::tuple< + // clang-format off + //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| + //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Type| + //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| | + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, 
Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 64, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 48, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 192, 64, 4, 8, 32, 32, 3, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 32, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 192, 32, 4, 8, 32, 32, 3, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 64, 32, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + 
DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16> + // clang-format on + >; + +void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck( + std::vector, PassThrough>>>& instances) { + ck::tensor_operation::device::instance::add_device_operation_instances( + instances, device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck{}); +} + +void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck( + std::vector, PassThrough>>>& instances) { + ck::tensor_operation::device::instance::add_device_operation_instances( + instances, device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck{}); +} + +} // namespace internal +} // namespace blas +} // namespace tunable +} // namespace rocm +} // namespace onnxruntime + +#endif diff --git a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instance.cu b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instance.cu new file mode 100644 index 0000000000000..1a0d45df82a71 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instance.cu @@ -0,0 +1,94 @@ +// SPDX-License-Identifier: MIT +// Modifications Copyright (c) Microsoft. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#if defined(USE_COMPOSABLE_KERNEL) && !defined(DISABLE_FLOAT8_TYPES) + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +#include "contrib_ops/rocm/math/gemm_float8_ck.cuh" + +namespace onnxruntime { +namespace rocm { +namespace tunable { +namespace blas { +namespace internal { + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle; + +template +using device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_generic = std::tuple< + // clang-format off + //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#########################| | | | | | | | Operation| 
Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNKPadding, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 1, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 2, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNKPadding, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 1, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 2, F16> + // clang-format on + >; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +template +using device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances = std::tuple< + // clang-format off + //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 
1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 4>, 8, F16> + // clang-format on + >; + +void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_ck( + std::vector, PassThrough>>>& instances) { + ck::tensor_operation::device::instance::add_device_operation_instances( + instances, device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances{}); + ck::tensor_operation::device::instance::add_device_operation_instances( + instances, device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_generic{}); +} + +void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_ck( + std::vector, PassThrough>>>& instances) { + ck::tensor_operation::device::instance::add_device_operation_instances( + instances, device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances{}); + 
ck::tensor_operation::device::instance::add_device_operation_instances( + instances, device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_generic{}); +} + +} // namespace internal +} // namespace blas +} // namespace tunable +} // namespace rocm +} // namespace onnxruntime + +#endif diff --git a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instance_original.cu b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instance_original.cu new file mode 100644 index 0000000000000..a0628802ec09e --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instance_original.cu @@ -0,0 +1,97 @@ +// SPDX-License-Identifier: MIT +// Modifications Copyright (c) Microsoft. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#if defined(USE_COMPOSABLE_KERNEL) && !defined(DISABLE_FLOAT8_TYPES) + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +#include "contrib_ops/rocm/math/gemm_float8_ck.cuh" + +namespace onnxruntime { +namespace rocm { +namespace tunable { +namespace blas { +namespace internal { + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle; + +template +using device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_generic = std::tuple< + // clang-format off + //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNKPadding, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 1, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 8>, 2>, + 
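These files enumerate many DeviceGemmXdlSplitKCShuffle tile configurations (block size, per-block M/N/K0, vector widths, and so on). They sit under the tunable::blas namespace, which suggests the instances are candidates for a tuning step that times each one on the actual problem shape and keeps the fastest. The sketch below shows that idea schematically with hypothetical names (Candidate, PickFastest); it is not ONNX Runtime's actual TunableOp interface:

#include <chrono>
#include <functional>
#include <iostream>
#include <limits>
#include <vector>

// A candidate kernel: in the real code `run` would launch one CK instance.
struct Candidate {
  const char* name;
  std::function<void()> run;
};

// Time each candidate and return the index of the fastest one.
int PickFastest(const std::vector<Candidate>& candidates, int warmup = 1, int iters = 5) {
  int best = 0;  // assumes a non-empty candidate list
  double best_ms = std::numeric_limits<double>::max();
  for (int i = 0; i < static_cast<int>(candidates.size()); ++i) {
    for (int w = 0; w < warmup; ++w) candidates[i].run();
    const auto t0 = std::chrono::steady_clock::now();
    for (int it = 0; it < iters; ++it) candidates[i].run();
    const auto t1 = std::chrono::steady_clock::now();
    const double ms = std::chrono::duration<double, std::milli>(t1 - t0).count() / iters;
    if (ms < best_ms) { best_ms = ms; best = i; }
  }
  return best;
}

int main() {
  auto work = [](int n) { volatile double acc = 0; for (int i = 0; i < n; ++i) acc += i * 0.5; };
  std::vector<Candidate> candidates = {
      {"tile_256x128", [&] { work(200000); }},
      {"tile_64x32", [&] { work(50000); }},
  };
  std::cout << "fastest: " << candidates[PickFastest(candidates)].name << "\n";
}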
DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNKPadding, 64, 32, 32, 4, 8, 32, 32, 1, 1, S<1, 2, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 1, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 2> + // clang-format on + >; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +template +using device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck = std::tuple< + // clang-format off + //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| + //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Type| + //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| | + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 256, 64, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 48, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 256, 192, 64, 4, 8, 32, 32, 3, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 
1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 128, 32, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 128, 192, 32, 4, 8, 32, 32, 3, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 128, 64, 32, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16> + // clang-format on + >; + +void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck( + std::vector, PassThrough, PassThrough>>>& instances) { + ck::tensor_operation::device::instance::add_device_operation_instances( + instances, device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck{}); + ck::tensor_operation::device::instance::add_device_operation_instances( + instances, device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_generic{}); +} + +void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck( + 
std::vector, PassThrough, PassThrough>>>& instances) { + ck::tensor_operation::device::instance::add_device_operation_instances( + instances, device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck{}); + ck::tensor_operation::device::instance::add_device_operation_instances( + instances, device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_generic{}); +} + +} // namespace internal +} // namespace blas +} // namespace tunable +} // namespace rocm +} // namespace onnxruntime + +#endif diff --git a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc new file mode 100644 index 0000000000000..7dbb24463961e --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc @@ -0,0 +1,347 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/shared_library/provider_api.h" +#include "core/providers/rocm/rocm_common.h" + +using namespace onnxruntime::common; + +namespace onnxruntime { +namespace contrib { +namespace rocm { +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, GridSample); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, FastGelu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, FastGelu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, FastGelu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FastGelu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, Gelu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, Gelu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, Gelu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, Gelu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BiasGelu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, BiasSplitGelu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, BiasSplitGelu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, BiasAdd); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, BiasAdd); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, QuickGelu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, QuickGelu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, QuickGelu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, TransposeMatMul); // backward compatibility +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, TransposeMatMul); // backward compatibility +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, TransposeMatMul); // backward compatibility +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, FusedMatMul); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, FusedMatMul); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, FusedMatMul); +// class 
ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, RelativePositionBias); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, RelativePositionBias); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, GatedRelativePositionBias); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, GatedRelativePositionBias); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, RemovePadding); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, RemovePadding); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, RestorePadding); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, RestorePadding); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, Rfft); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, Rfft); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, Rfft); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, Irfft); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, Irfft); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, Irfft); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, ComplexMul); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, ComplexMul); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, ComplexMulConj); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, ComplexMulConj); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BiasSoftmax); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BiasDropout); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BitmaskDropout); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BitmaskBiasDropout); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, NGramRepeatBlock); + +// These ops were experimental ops in onnx domain which have been removed now. 
We add them here as +// contrib ops to maintain backward compatibility +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, Affine); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double, Affine); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, Affine); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, Attention); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, Attention); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, PackedAttention); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, PackedAttention); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, PackedMultiHeadAttention); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, PackedMultiHeadAttention); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BeamSearch); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, ConvTransposeWithDynamicPads); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, Crop); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double, Crop); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, Crop); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, MultiHeadAttention); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, MultiHeadAttention); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, GroupQueryAttention); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, GroupQueryAttention); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, DecoderAttention); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, DecoderAttention); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, int32_t, DynamicSlice); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, int64_t, DynamicSlice); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, EmbedLayerNormalization); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, EmbedLayerNormalization); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, GreedySearch); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, GroupNorm); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, NhwcConv); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, NhwcConv); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, ImageScaler); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double, ImageScaler); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, ImageScaler); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, LongformerAttention); +class 
ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, LongformerAttention); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, ParametricSoftplus); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double, ParametricSoftplus); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, ParametricSoftplus); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, Sampling); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, ScaledTanh); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double, ScaledTanh); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, ScaledTanh); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, SkipGroupNorm); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, SkipLayerNormalization); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, SkipLayerNormalization); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, SkipSimplifiedLayerNormalization); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, SkipSimplifiedLayerNormalization); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, ThresholdedRelu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double, ThresholdedRelu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, ThresholdedRelu); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 16, float_float_float, LayerNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 16, double_double_double, LayerNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 16, MLFloat16_float_MLFloat16, LayerNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 16, float_float_MLFloat16, LayerNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 16, MLFloat16_float_float, LayerNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 16, BFloat16_float_BFloat16, LayerNormalization); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float_float_float, SimplifiedLayerNormalization); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double_double_double, SimplifiedLayerNormalization); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16_float_MLFloat16, SimplifiedLayerNormalization); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float_float_MLFloat16, SimplifiedLayerNormalization); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16_float_float, SimplifiedLayerNormalization); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, Inverse); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, MatMulNBits); +class 
ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, MatMulNBits); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, Trilu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, int8_t_MLFloat16, QuantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, uint8_t_MLFloat16, QuantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, int8_t_MLFloat16, DequantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, uint8_t_MLFloat16, DequantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float_int8_t, QAttention); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_int8_t, QAttention); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, FusedConv); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, FusedConv); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, TransposeMatMul); // backward compatibility +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FusedMatMul); +// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, QOrderedMatMul); +// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, QOrderedLayerNormalization); +// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, QOrderedGelu); +// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, QuantizeWithOrder); +// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, DequantizeWithOrder); +// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, QOrderedAttention); +// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, QOrderedLongformerAttention); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, DecoderMaskedSelfAttention); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, DecoderMaskedSelfAttention); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, DecoderMaskedMultiHeadAttention); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, DecoderMaskedMultiHeadAttention); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, GemmFastGelu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, GemmFastGelu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, GemmFastGelu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, GemmFloat8); + +#ifdef ENABLE_ATEN +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kPytorchAtenDomain, 1, ATen); +#endif + +#ifdef ENABLE_TRAINING_OPS +// Should remove the shrunken_gather include from ENABLE_TRAINING_OPS once 1). compute optimizer is enabled for inference or +// 2). this is needed by inference for other purpose. 
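Each ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME / ONNX_OPERATOR_KERNEL_CLASS_NAME line in this block forward-declares a kernel class whose name is derived from the provider, domain, opset version, data type, and operator. As a rough, hypothetical illustration of how such a name-pasting macro can work (the real ONNX Runtime macros are more elaborate and also cover versioned ranges like the LayerNormalization entries above):

#include <iostream>
#include <typeinfo>

// Simplified stand-in: paste the pieces into one unique identifier.
#define TYPED_KERNEL_CLASS_NAME(provider, domain, ver, type, op) \
  provider##_##domain##_ver##ver##_##type##_##op

// A forward declaration like the ones in the list above would expand to
// `class Rocm_MSDomain_ver1_float_FastGelu;`.
class TYPED_KERNEL_CLASS_NAME(Rocm, MSDomain, 1, float, FastGelu) {};

int main() {
  TYPED_KERNEL_CLASS_NAME(Rocm, MSDomain, 1, float, FastGelu) kernel;
  std::cout << typeid(kernel).name() << "\n";
}

Because every combination expands to a distinct class name, the same operator can be declared once per supported data type, which is exactly what the long list here does.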
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, ShrunkenGather); +#endif + +#ifdef ORT_USE_NCCL +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, AllReduce); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, AllGather); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, AllToAll); +#endif + +template <> +KernelCreateInfo BuildKernelCreateInfo() { + KernelCreateInfo info; + return info; +} + +// clang-format off +Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) { + static const BuildKernelCreateInfoFn function_table[] = { + BuildKernelCreateInfo, // default entry to avoid the list become empty after ops-reducing + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // backward compatibility + BuildKernelCreateInfo, // backward compatibility + BuildKernelCreateInfo, // backward compatibility + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + // These ops were experimental ops in onnx domain which have been removed now. 
We add them here as + // contrib ops to maintain backward compatibility + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + // TransposedMatMul is still here for backward compatibility + BuildKernelCreateInfo, // backward compatibility + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + +#ifdef ENABLE_ATEN + BuildKernelCreateInfo, +#endif + +#ifdef ENABLE_TRAINING_OPS + // Should remove the shrunken_gather include from ENABLE_TRAINING_OPS once 1). compute optimizer is enabled for inference or + // 2). this is needed by inference for other purpose. 
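The registration table above, and the loop that follows it in RegisterRocmContribKernels, rely on two details: a BuildKernelCreateInfo<void> default entry keeps the array non-empty when ops are compiled out, and entries whose kernel_def comes back null are skipped. A compilable miniature of that pattern, with toy KernelDef and KernelCreateInfo types standing in for the real ones:

#include <iostream>
#include <memory>
#include <vector>

struct KernelDef { const char* name; };
struct KernelCreateInfo { std::unique_ptr<KernelDef> kernel_def; };

template <typename T>
KernelCreateInfo BuildKernelCreateInfo();

// Default entry: an empty info, filtered out by the nullptr check below.
template <>
KernelCreateInfo BuildKernelCreateInfo<void>() { return {}; }

struct FastGeluKernel {};
template <>
KernelCreateInfo BuildKernelCreateInfo<FastGeluKernel>() {
  KernelCreateInfo info;
  info.kernel_def = std::make_unique<KernelDef>();
  info.kernel_def->name = "FastGelu";
  return info;
}

int main() {
  using BuildFn = KernelCreateInfo (*)();
  const BuildFn function_table[] = {
      BuildKernelCreateInfo<void>,            // placeholder, keeps the table non-empty
      BuildKernelCreateInfo<FastGeluKernel>,  // a real kernel
  };
  std::vector<KernelCreateInfo> registered;
  for (const auto& entry : function_table) {
    KernelCreateInfo info = entry();
    if (info.kernel_def != nullptr) {  // skip disabled/placeholder entries
      registered.push_back(std::move(info));
    }
  }
  std::cout << "registered " << registered.size() << " kernel(s)\n";  // prints 1
}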
+ BuildKernelCreateInfo, +#endif + +#ifdef ORT_USE_NCCL + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, +#endif + + }; + + for (auto& function_table_entry : function_table) { + KernelCreateInfo info = function_table_entry(); + if (info.kernel_def != nullptr) { // filter disabled entries where type is void + ORT_RETURN_IF_ERROR(kernel_registry.Register(std::move(info))); + } + } + + return Status::OK(); +} +// clang-format on + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.h b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.h new file mode 100644 index 0000000000000..db9a5d4fcd83e --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.h @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +Status RegisterRocmContribKernels(KernelRegistry& kernel_registry); + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc index 130dd0c25a880..a5ab63d74df24 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc @@ -165,7 +165,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const { << " let query_pos = m + local_id.y + past_sequence_length;\n" << " let key_pos = n + local_id.x;\n" << " if (key_pos > query_pos) {\n" - << " sum = -3.4028234663852886e+38; // Set to very negative value for masking\n" + << " sum = -3.40282e+38; // Set to very negative value for masking\n" << " }\n"; } @@ -272,7 +272,7 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const { << "let effective_seq_length = seq_causal_length;\n"; } shader.MainFunctionBody() - << "var thread_max_vector = f32_val_t(-3.4028234663852886e+38f);\n" + << "var thread_max_vector = f32_val_t(-3.402823e+38f);\n" << "for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < effective_seq_length; i++) {\n" << " let actual_pos = local_offset + i + start_offset;\n" << " if (!should_apply_local_window || actual_pos < seq_causal_length) {\n" @@ -289,7 +289,7 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const { } else if (use_smooth_softmax_) { shader.MainFunctionBody() << "var max_value: f32 = 0.0;\n"; } else { - shader.MainFunctionBody() << "var max_value = f32(-3.4028234663852886e+38f);\n"; + shader.MainFunctionBody() << "var max_value = f32(-3.402823e+38f);\n"; } shader.MainFunctionBody() << "for (var i = 0u; i < " << work_group_size_ << "; i++) {\n" diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc index 2a67dfdb07912..606dbfde15c2c 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc @@ -421,7 +421,7 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co indirect_buffer_ptr, tile_size)); Q = &query_output; } else { - ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, use_seqlen_k ? 
seqlen_k : nullptr, indirect_buffer_ptr)); + ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, use_indirect_dispatch ? seqlen_k : nullptr, indirect_buffer_ptr)); } if (parameters.sequence_length_ > 1) { @@ -571,8 +571,8 @@ Status RunSplitPackedQKVWithRotaryEmbeddingAndCopyKV(onnxruntime::webgpu::Comput {static_cast(params.kv_hidden_size_ / components)}, {static_cast(params.num_heads_)}, {static_cast(params.kv_num_heads_)}, - {static_cast(head_size_vec)}, - {static_cast(half_rotary_embedding_dim_vec)}, + {head_size_vec}, + {half_rotary_embedding_dim_vec}, {present_sequence_length}, {tile_size}, {static_cast(dispatch_size)}, diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.wgsl.template index ff8e4ecc08bab..a5922ec9512fd 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.wgsl.template @@ -26,7 +26,7 @@ fn get_total_sequence_length() -> u32 { #if is_fp16 const min_value = q_element_t(-65504.0); #else -const min_value = q_element_t(-3.4028234663852886e+38f); +const min_value = q_element_t(-3.402823e+38f); #endif // For max performance max_k_step should be the same as sg_size, however we might run out of registers diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkt.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkt.wgsl.template index ac9a157492007..c6f768beffa0f 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkt.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkt.wgsl.template @@ -93,7 +93,7 @@ $MAIN { if (local_idx == 0u) { // Calculate the max and sum in current split. - var l_max = f32(-3.4028234663852886e+38f); + var l_max = f32(-3.402823e+38f); var l_sum = f32(0); for (var i = 0u; i < tile_size && (total_seq_offset + i) < total_sequence_length; i++) { l_max = max(l_max, f32(tile_qk[i])); diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_split_vx.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_split_vx.wgsl.template index a113e96130985..37cf7e8f11b1f 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_split_vx.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_split_vx.wgsl.template @@ -54,7 +54,7 @@ $MAIN { // Calculate the global max and sum in qk. 
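The decode qkt shader above computes a local maximum and a local exp-sum per sequence tile, and the split_vx shader whose hunk continues below folds those per-split statistics into a global max and sum. That merge has to be numerically stable, rescaling each partial sum by exp(l_max - g_max); a small C++ rendering of the reduction, with hypothetical names:

#include <algorithm>
#include <cmath>
#include <iostream>
#include <vector>

// Per-split softmax statistics: local max and sum of exp(x - local_max).
struct SplitStats { float l_max; float l_sum; };

// Stable merge across splits, mirroring the reduction the decode shaders perform.
SplitStats MergeSplits(const std::vector<SplitStats>& splits) {
  float g_max = -3.402823e+38f;  // finite stand-in for minus infinity, like the shader constants
  for (const auto& s : splits) g_max = std::max(g_max, s.l_max);
  float g_sum = 0.0f;
  for (const auto& s : splits) g_sum += s.l_sum * std::exp(s.l_max - g_max);
  return {g_max, g_sum};
}

int main() {
  const std::vector<SplitStats> splits = {{1.5f, 2.0f}, {3.0f, 1.2f}};
  const SplitStats g = MergeSplits(splits);
  std::cout << "global max=" << g.l_max << " sum=" << g.l_sum << "\n";
}

This also explains why the exact spelling of the initial maximum is not critical: it only needs to be lower than any real score, so a shorter literal such as -3.402823e+38 works as well as the full-precision FLT_MAX spelling.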
if (head_idx < uniforms.num_heads) { - var g_max = f32(-3.4028234663852886e+38f); + var g_max = f32(-3.402823e+38f); var g_sum = f32(0); for (var i = 0u; i < num_total_seq_length_tile; i++) { diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc index 416a895e61745..05717fd2fe686 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc @@ -128,8 +128,8 @@ Status RunSplitPackedQKVWithRotaryEmbedding(onnxruntime::webgpu::ComputeContext& {static_cast(params.kv_hidden_size_ / components)}, {static_cast(params.num_heads_)}, {static_cast(params.kv_num_heads_)}, - {static_cast(head_size_vec)}, - {static_cast(half_rotary_embedding_dim_vec)}, + {head_size_vec}, + {half_rotary_embedding_dim_vec}, {static_cast(dispatch_size)}, }) .SetDispatchGroupSize((dispatch_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE); diff --git a/onnxruntime/contrib_ops/webgpu/moe/gate.wgsl.template b/onnxruntime/contrib_ops/webgpu/moe/gate.wgsl.template index 6e0d4c7299793..1214777009a8d 100644 --- a/onnxruntime/contrib_ops/webgpu/moe/gate.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/moe/gate.wgsl.template @@ -18,7 +18,7 @@ const K: u32 = k; #if is_fp16 const MAX_FLOAT: f16 = 65504.0; #else -const MAX_FLOAT: f32 = 3.4028234663852886e+38; +const MAX_FLOAT: f32 = 3.402823466e+38; #endif var shared_vals: array; diff --git a/onnxruntime/core/framework/allocation_planner.cc b/onnxruntime/core/framework/allocation_planner.cc index 1c80d83f99feb..e77496b6e8196 100644 --- a/onnxruntime/core/framework/allocation_planner.cc +++ b/onnxruntime/core/framework/allocation_planner.cc @@ -499,7 +499,8 @@ class PlannerImpl { /*! \brief Given a tensor-type, return the size of an element of the tensor. */ static size_t GetElementSize(const DataType& tensor_type) { - MLDataType ml_data_type = DataTypeImpl::GetDataType(*tensor_type); + const TypeProto& type_proto = ONNX_NAMESPACE::Utils::DataTypeUtils::ToTypeProto(tensor_type); + MLDataType ml_data_type = DataTypeImpl::TypeFromProto(type_proto); const TensorTypeBase* tensor_type_base = ml_data_type->AsTensorType(); ORT_ENFORCE(nullptr != tensor_type_base); MLDataType elt_type = tensor_type_base->GetElementType(); diff --git a/onnxruntime/core/framework/ort_value_name_idx_map.h b/onnxruntime/core/framework/ort_value_name_idx_map.h index 6035dc4e85242..76e7e369514d4 100644 --- a/onnxruntime/core/framework/ort_value_name_idx_map.h +++ b/onnxruntime/core/framework/ort_value_name_idx_map.h @@ -33,7 +33,7 @@ class OrtValueNameIdxMap { common::Status GetIdx(std::string_view name, int& idx) const { idx = -1; - auto it = map_.find(name); + auto it = map_.find(std::string(name)); if (it == map_.end()) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Could not find OrtValue with name '", name, "'"); } diff --git a/onnxruntime/core/graph/contrib_ops/nhwc_inference_context.h b/onnxruntime/core/graph/contrib_ops/nhwc_inference_context.h index 94ef87fb069af..bc52a45adfd43 100644 --- a/onnxruntime/core/graph/contrib_ops/nhwc_inference_context.h +++ b/onnxruntime/core/graph/contrib_ops/nhwc_inference_context.h @@ -83,8 +83,7 @@ class NhwcInferenceContext : public ONNX_NAMESPACE::InferenceContext { const int rank = nchw_shape.dim_size(); // N and C dims are required. 
Some operators like AveragePool allow 1D input if (rank < 3) { - *nhwc_tp.mutable_tensor_type()->mutable_shape() = nchw_shape; - return; + fail_shape_inference("Output tensor must have at least 3 dimensions"); } // Convert output shape from N, C, H {, W, ...} to N, H {, W, ...}, C. @@ -106,8 +105,8 @@ class NhwcInferenceContext : public ONNX_NAMESPACE::InferenceContext { const int rank = nhwc_shape.dim_size(); // N and C dims are required. Some operators like AveragePool allow 1D input. if (rank < 3) { - *nchw_tp.mutable_tensor_type()->mutable_shape() = nhwc_shape; - return; + fail_shape_inference( + "Tensor must have at least 3 dimensions to convert between channels first and channels last."); } // Convert input shape from {N, H, W, ..., C} to {N, C, H, W, ...}. diff --git a/onnxruntime/core/platform/telemetry.cc b/onnxruntime/core/platform/telemetry.cc index 1eb03af3befa4..6cbbdd4e0a7ef 100644 --- a/onnxruntime/core/platform/telemetry.cc +++ b/onnxruntime/core/platform/telemetry.cc @@ -81,10 +81,6 @@ void Telemetry::LogSessionCreation(uint32_t session_id, int64_t ir_version, cons ORT_UNUSED_PARAMETER(captureState); } -void Telemetry::LogCompileModel(uint32_t session_id) const { - ORT_UNUSED_PARAMETER(session_id); -} - void Telemetry::LogRuntimeError(uint32_t session_id, const common::Status& status, const char* file, const char* function, uint32_t line) const { ORT_UNUSED_PARAMETER(session_id); diff --git a/onnxruntime/core/platform/telemetry.h b/onnxruntime/core/platform/telemetry.h index 9c2859f7634b6..b60345e1b8a80 100644 --- a/onnxruntime/core/platform/telemetry.h +++ b/onnxruntime/core/platform/telemetry.h @@ -66,8 +66,6 @@ class Telemetry { const std::string& loadedFrom, const std::vector& execution_provider_ids, bool use_fp16, bool captureState) const; - virtual void LogCompileModel(uint32_t session_id) const; - virtual void LogRuntimeError(uint32_t session_id, const common::Status& status, const char* file, const char* function, uint32_t line) const; diff --git a/onnxruntime/core/platform/windows/telemetry.cc b/onnxruntime/core/platform/windows/telemetry.cc index 693e265af46b1..2e5d334856278 100644 --- a/onnxruntime/core/platform/windows/telemetry.cc +++ b/onnxruntime/core/platform/windows/telemetry.cc @@ -334,20 +334,6 @@ void WindowsTelemetry::LogSessionCreation(uint32_t session_id, int64_t ir_versio } } -void WindowsTelemetry::LogCompileModel(uint32_t session_id) const { - if (global_register_count_ == 0 || enabled_ == false) - return; - - TraceLoggingWrite(telemetry_provider_handle, - "CompileModel", - TraceLoggingBool(true, "UTCReplace_AppSessionGuid"), - TelemetryPrivacyDataTag(PDT_ProductAndServiceUsage), - TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES), - // Telemetry info - TraceLoggingUInt8(0, "schemaVersion"), - TraceLoggingUInt32(session_id, "sessionId")); -} - void WindowsTelemetry::LogRuntimeError(uint32_t session_id, const common::Status& status, const char* file, const char* function, uint32_t line) const { if (global_register_count_ == 0 || enabled_ == false) diff --git a/onnxruntime/core/platform/windows/telemetry.h b/onnxruntime/core/platform/windows/telemetry.h index 044feec071223..261d14a7fed8c 100644 --- a/onnxruntime/core/platform/windows/telemetry.h +++ b/onnxruntime/core/platform/windows/telemetry.h @@ -59,8 +59,6 @@ class WindowsTelemetry : public Telemetry { const std::string& loadedFrom, const std::vector& execution_provider_ids, bool use_fp16, bool captureState) const override; - void LogCompileModel(uint32_t session_id) const override; - void 
LogRuntimeError(uint32_t session_id, const common::Status& status, const char* file, const char* function, uint32_t line) const override; diff --git a/onnxruntime/core/providers/js/operators/unary.cc b/onnxruntime/core/providers/js/operators/unary.cc index 26144e6ba3995..ef977161bcc37 100644 --- a/onnxruntime/core/providers/js/operators/unary.cc +++ b/onnxruntime/core/providers/js/operators/unary.cc @@ -126,7 +126,7 @@ JSEP_ELEMENTWISE_TYPED_KERNEL(Not, 1, bool, Not) // activation -JSEP_CLASS_IMPL_ATTRIBUTE_FLOAT_2_DEFAULT(ClipV10, Clip, min, 3.4028234663852886e+38f, max, -3.4028234663852886e+38f) +JSEP_CLASS_IMPL_ATTRIBUTE_FLOAT_2_DEFAULT(ClipV10, Clip, min, 3.402823e+38f, max, -3.402823e+38f) JSEP_ELEMENTWISE_VERSIONED_KERNEL(Clip, 6, 10, ClipV10) JSEP_KERNEL_IMPL(Clip, Clip) ONNX_OPERATOR_VERSIONED_KERNEL_EX(Clip, kOnnxDomain, 11, 11, kJsExecutionProvider, diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc index d148c4191d5d7..e2a8005aba1da 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc @@ -1407,30 +1407,9 @@ std::unique_ptr NvExecutionProvider::GetSubGraph(SubGraph_t gra } // Find inputs and outputs of the subgraph - std::unique_ptr sub_graph = onnxruntime::IndexedSubGraph::Create(); - std::unordered_map original_inputs; - - // These maps store the inputs and outputs of the subgraph. - // Please note that the inputs and outputs of the maps will be dynamically updated during node iteration - // to determine the final inputs and outputs of the subgraph. - std::unordered_map fused_inputs, fused_outputs; - - // This map stores the node's output that will be consumed by another node outside of this subgraph. - // So the node's output should be put into the subgraph's output list. - std::unordered_map fused_outputs_to_add; - - // This map stores the node's output that is original graph's output. - // So the node's output should be put into the subgraph's output list. - std::unordered_map graph_outputs_to_add; - + std::unordered_map original_inputs, fused_inputs, fused_outputs, fused_outputs_to_add, graph_outputs_to_add; std::unordered_set erased; - - // This is the relative ordering that ensures node's input or output being added to the 'fused_inputs', - // 'fused_outputs', 'fused_outputs_to_add' and 'graph_outputs_to_add' maps is associated with a relative order index. - // Items added earlier receive a smaller order index than items added later. - // When constructing the final sub_graph's input or output lists, entries with smaller - // order indices will appear before those with larger indices. 
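The comment block in the GetSubGraph hunk above spells out the ordering contract for these maps: each input or output is stored together with the order index at which it was added, and the final subgraph input/output lists are produced by sorting on that index so earlier-added entries come first. A compact sketch of that collect-then-sort scheme with hypothetical names:

#include <algorithm>
#include <iostream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Turn {name -> first-seen order index} into a list sorted by that index.
std::vector<std::string> ToOrderedList(const std::unordered_map<std::string, int>& items) {
  std::vector<std::pair<std::string, int>> pairs(items.begin(), items.end());
  std::sort(pairs.begin(), pairs.end(),
            [](const auto& a, const auto& b) { return a.second < b.second; });
  std::vector<std::string> ordered;
  ordered.reserve(pairs.size());
  for (const auto& p : pairs) ordered.push_back(p.first);
  return ordered;
}

int main() {
  std::unordered_map<std::string, int> fused_inputs;
  int input_order = 0;
  for (std::string name : {"B", "A", "C"}) {
    if (fused_inputs.find(name) == fused_inputs.end()) {
      fused_inputs[name] = input_order++;  // first-seen index decides the final position
    }
  }
  for (const auto& name : ToOrderedList(fused_inputs)) std::cout << name << " ";
  std::cout << "\n";  // prints: B A C
}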
int input_order = 0; int output_order = 0; @@ -1449,7 +1428,7 @@ std::unique_ptr NvExecutionProvider::GetSubGraph(SubGraph_t gra erased.insert(input); } else if (erased.find(input) == erased.end()) { // Only when input is neither in output list nor erased list, add the input to input list - fused_inputs.insert({input, input_order++}); + fused_inputs[input] = input_order++; } } @@ -1464,7 +1443,7 @@ std::unique_ptr NvExecutionProvider::GetSubGraph(SubGraph_t gra erased.insert(input); } else if (erased.find(input) == erased.end()) { // Only when input is neither in output list nor erased list, add the input to input list - fused_inputs.insert({input, input_order++}); + fused_inputs[input] = input_order++; } } @@ -1485,32 +1464,38 @@ std::unique_ptr NvExecutionProvider::GetSubGraph(SubGraph_t gra } else { output = (it->GetNode()).ImplicitInputDefs()[it->GetDstArgIndex() - static_cast(it->GetNode().InputDefs().size())]; } - - if (node_set.find(node_idx) == node_set.end()) { - // This output will be consumed by another node outside of this subgraph. - // So the output should be put into the subgraph's output list. - fused_outputs_to_add.insert({output, output_order++}); + if (node_set.find(node_idx) != node_set.end()) { + const auto& iter = fused_inputs.find(output); + if (iter != fused_inputs.end()) { + fused_inputs.erase(iter); + erased.insert(output); + } else if (erased.find(output) == erased.end()) { + if (graph_output_names.find(output->Name()) != graph_output_names.end()) { + graph_outputs_to_add[output] = output_order; + } + fused_outputs[output] = output_order++; + } + } else { + fused_outputs_to_add[output] = output_order++; } } - } - - for (const auto& output : node->OutputDefs()) { - const auto& it = fused_inputs.find(output); - if (it != fused_inputs.end()) { - fused_inputs.erase(it); - erased.insert(output); - } else if (erased.find(output) == erased.end()) { - if (graph.GetGraph().GetConsumerNodes(output->Name()).size() > 0) { - // Only when output is neither in input list nor erased list, - // and the output is consumed by another node, add the output to output list - fused_outputs.insert({output, output_order++}); + } else { + for (const auto& output : node->OutputDefs()) { + const auto& it = fused_inputs.find(output); + if (it != fused_inputs.end()) { + fused_inputs.erase(it); + erased.insert(output); } - } + // Only when output is neither in input list nor erased list, and the output is consumed by another node, add the output to output list + else if (erased.find(output) == erased.end()) { + if (graph_output_names.find(output->Name()) != graph_output_names.end()) { + graph_outputs_to_add[output] = output_order; + } - if (graph_output_names.find(output->Name()) != graph_output_names.end()) { - // This output is the graph's output. - // So the output should be put into the subgraph's output list. 
- graph_outputs_to_add.insert({output, output_order++}); + if (graph.GetGraph().GetConsumerNodes(output->Name()).size() > 0) { + fused_outputs[output] = output_order++; + } + } } } } diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc index 0bb3accb4d754..4d183b95bd938 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc @@ -76,9 +76,6 @@ Status BaseOpBuilder::ProcessDataTypes(QnnModelWrapper& qnn_model_wrapper, return CheckHtpDataTypes(input_qnn_dtypes, output_qnn_dtypes); } else if (IsGpuBackend(qnn_model_wrapper.GetQnnBackendType())) { return CheckGpuDataTypes(input_qnn_dtypes, output_qnn_dtypes); - } else if (IsIrBackend(qnn_model_wrapper.GetQnnBackendType())) { - // TODO: CheckIrDataTypes - return Status::OK(); } return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Only support backend: CPU, HTP and GPU"); } diff --git a/onnxruntime/core/providers/qnn/builder/qnn_def.cc b/onnxruntime/core/providers/qnn/builder/qnn_def.cc index 9f28e2609faa1..f3d81d7d2fdd7 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_def.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_def.cc @@ -574,10 +574,6 @@ bool QnnOpConfigWrapper::CreateQnnGraphOp(const QNN_INTERFACE_VER_TYPE& qnn_inte return true; } -bool IsIrBackend(QnnBackendType backend_type) { - return backend_type == QnnBackendType::SERIALIZER; -} - bool IsNpuBackend(QnnBackendType backend_type) { return backend_type == QnnBackendType::HTP || backend_type == QnnBackendType::DSP; } diff --git a/onnxruntime/core/providers/qnn/builder/qnn_def.h b/onnxruntime/core/providers/qnn/builder/qnn_def.h index 77508f3934a20..42f4d7bb60f34 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_def.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_def.h @@ -96,8 +96,6 @@ enum class QnnBackendType : uint8_t { SERIALIZER, }; -bool IsIrBackend(QnnBackendType backend_type); - bool IsCpuBackend(QnnBackendType backend_type); bool IsNpuBackend(QnnBackendType backend_type); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.cc b/onnxruntime/core/providers/qnn/builder/qnn_model.cc index 8973a4efa8ba1..85901ab6fdfec 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_model.cc @@ -222,14 +222,14 @@ Status QnnModel::SetupQnnInputOutput(const logging::Logger& logger) { auto result = SetupTensors(qnn_input_infos_, graph_info_->InputTensors()); if (Status::OK() != result) { - const std::string message = "Failed to setup QNN input tensors for graph: " + graph_info_->Name() + ". " + result.ErrorMessage(); + const std::string message = "Failed to setup QNN input tensors for graph: " + graph_info_->Name(); LOGS(logger, ERROR) << message; return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, message); } result = SetupTensors(qnn_output_infos_, graph_info_->OutputTensors(), false); if (Status::OK() != result) { - const std::string message = "Failed to setup QNN output tensors for graph: " + graph_info_->Name() + ". 
" + result.ErrorMessage(); + const std::string message = "Failed to setup QNN output tensors for graph: " + graph_info_->Name(); LOGS(logger, ERROR) << message; return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, message); } diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index e5b48da33fbc3..cd0c0e4bffdb5 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -2035,30 +2035,9 @@ std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGraph } // Find inputs and outputs of the subgraph - std::unique_ptr sub_graph = onnxruntime::IndexedSubGraph::Create(); - std::unordered_map original_inputs; - - // These maps store the inputs and outputs of the subgraph. - // Please note that the inputs and outputs of the maps will be dynamically updated during node iteration - // to determine the final inputs and outputs of the subgraph. - std::unordered_map fused_inputs, fused_outputs; - - // This map stores the node's output that will be consumed by another node outside of this subgraph. - // So the node's output should be put into the subgraph's output list. - std::unordered_map fused_outputs_to_add; - - // This map stores the node's output that is original graph's output. - // So the node's output should be put into the subgraph's output list. - std::unordered_map graph_outputs_to_add; - + std::unordered_map original_inputs, fused_inputs, fused_outputs, fused_outputs_to_add, graph_outputs_to_add; std::unordered_set erased; - - // This is the relative ordering that ensures node's input or output being added to the 'fused_inputs', - // 'fused_outputs', 'fused_outputs_to_add' and 'graph_outputs_to_add' maps is associated with a relative order index. - // Items added earlier receive a smaller order index than items added later. - // When constructing the final sub_graph's input or output lists, entries with smaller - // order indices will appear before those with larger indices. int input_order = 0; int output_order = 0; @@ -2077,7 +2056,7 @@ std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGraph erased.insert(input); } else if (erased.find(input) == erased.end()) { // Only when input is neither in output list nor erased list, add the input to input list - fused_inputs.insert({input, input_order++}); + fused_inputs[input] = input_order++; } } @@ -2092,7 +2071,7 @@ std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGraph erased.insert(input); } else if (erased.find(input) == erased.end()) { // Only when input is neither in output list nor erased list, add the input to input list - fused_inputs.insert({input, input_order++}); + fused_inputs[input] = input_order++; } } @@ -2113,32 +2092,38 @@ std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGraph } else { output = (it->GetNode()).ImplicitInputDefs()[it->GetDstArgIndex() - static_cast(it->GetNode().InputDefs().size())]; } - - if (node_set.find(node_idx) == node_set.end()) { - // This output will be consumed by another node outside of this subgraph. - // So the output should be put into the subgraph's output list. 
- fused_outputs_to_add.insert({output, output_order++}); + if (node_set.find(node_idx) != node_set.end()) { + const auto& iter = fused_inputs.find(output); + if (iter != fused_inputs.end()) { + fused_inputs.erase(iter); + erased.insert(output); + } else if (erased.find(output) == erased.end()) { + if (graph_output_names.find(output->Name()) != graph_output_names.end()) { + graph_outputs_to_add[output] = output_order; + } + fused_outputs[output] = output_order++; + } + } else { + fused_outputs_to_add[output] = output_order++; } } - } - - for (const auto& output : node->OutputDefs()) { - const auto& it = fused_inputs.find(output); - if (it != fused_inputs.end()) { - fused_inputs.erase(it); - erased.insert(output); - } else if (erased.find(output) == erased.end()) { - if (graph.GetGraph().GetConsumerNodes(output->Name()).size() > 0) { - // Only when output is neither in input list nor erased list, - // and the output is consumed by another node, add the output to output list - fused_outputs.insert({output, output_order++}); + } else { + for (const auto& output : node->OutputDefs()) { + const auto& it = fused_inputs.find(output); + if (it != fused_inputs.end()) { + fused_inputs.erase(it); + erased.insert(output); } - } + // Only when output is neither in input list nor erased list, and the output is consumed by another node, add the output to output list + else if (erased.find(output) == erased.end()) { + if (graph_output_names.find(output->Name()) != graph_output_names.end()) { + graph_outputs_to_add[output] = output_order; + } - if (graph_output_names.find(output->Name()) != graph_output_names.end()) { - // This output is the graph's output. - // So the output should be put into the subgraph's output list. - graph_outputs_to_add.insert({output, output_order++}); + if (graph.GetGraph().GetConsumerNodes(output->Name()).size() > 0) { + fused_outputs[output] = output_order++; + } + } } } } diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/clip_op_builder.cc b/onnxruntime/core/providers/vsinpu/builders/impl/clip_op_builder.cc index 9948069c6779b..85096d0e262d7 100644 --- a/onnxruntime/core/providers/vsinpu/builders/impl/clip_op_builder.cc +++ b/onnxruntime/core/providers/vsinpu/builders/impl/clip_op_builder.cc @@ -78,8 +78,8 @@ bool ClipOpBuilder::HandleBuildOp(vsi::npu::GraphEP* graph_ep, LOGS_DEFAULT(INFO) << "Creating Clip Op."; if (node_unit.SinceVersion() <= 6) { NodeAttrHelper helper(node_unit.GetNode()); - auto min = helper.Get("min", -3.4028234663852886e+38f); - auto max = helper.Get("max", 3.4028234663852886e+38f); + auto min = helper.Get("min", -3.402e+38f); + auto max = helper.Get("max", 3.402e+38f); auto op = graph_ep->GetGraph()->CreateOperation(min, max); (*op).BindInputs(inputs).BindOutputs(outputs); graph_ep->GetOps().push_back(std::move(op)); diff --git a/onnxruntime/core/providers/webgpu/allocator.cc b/onnxruntime/core/providers/webgpu/allocator.cc index 3e1b87821fe2f..b3eb4b5061423 100644 --- a/onnxruntime/core/providers/webgpu/allocator.cc +++ b/onnxruntime/core/providers/webgpu/allocator.cc @@ -13,7 +13,7 @@ GpuBufferAllocator::GpuBufferAllocator(const BufferManager& buffer_manager, bool OrtMemoryInfo(WEBGPU_BUFFER, is_read_only_allocator ? 
OrtAllocatorType::OrtReadOnlyAllocator : OrtAllocatorType::OrtDeviceAllocator, - WebGpuDevice, + OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NONE, 0), OrtMemTypeDefault)), buffer_manager_{buffer_manager}, mapped_at_creation_{is_read_only_allocator && buffer_manager.SupportsUMA()} { diff --git a/onnxruntime/core/providers/webgpu/allocator.h b/onnxruntime/core/providers/webgpu/allocator.h index 74b3d669fcf3b..7c38b4557e078 100644 --- a/onnxruntime/core/providers/webgpu/allocator.h +++ b/onnxruntime/core/providers/webgpu/allocator.h @@ -11,11 +11,6 @@ namespace webgpu { class BufferManager; -inline constexpr OrtDevice WebGpuDevice{OrtDevice::GPU, - OrtDevice::MemType::DEFAULT, - OrtDevice::VendorIds::NONE, - 0}; - class GpuBufferAllocator : public IAllocator { public: GpuBufferAllocator(const BufferManager& buffer_manager, bool is_read_only_allocator); diff --git a/onnxruntime/core/providers/webgpu/compute_context.cc b/onnxruntime/core/providers/webgpu/compute_context.cc index d1a2011c8e191..ebe71c6ccfacd 100644 --- a/onnxruntime/core/providers/webgpu/compute_context.cc +++ b/onnxruntime/core/providers/webgpu/compute_context.cc @@ -6,25 +6,22 @@ namespace onnxruntime { namespace webgpu { - -ComputeContextBase::ComputeContextBase(WebGpuContext& webgpu_context, - const WebGpuExecutionProvider& ep, - const OpKernel& op_kernel) +ComputeContext::ComputeContext(OpKernelContext& kernel_context, + const OpKernel& op_kernel, + const WebGpuExecutionProvider& ep, + WebGpuContext& webgpu_context) : webgpu_context_{webgpu_context}, - ep_{ep}, - op_kernel_{op_kernel} { + kernel_context_{kernel_context}, + op_kernel_{op_kernel}, + ep_{ep} { } -const webgpu::BufferManager& ComputeContextBase::BufferManagerAccessor::Get(const ComputeContextBase& context) { +const webgpu::BufferManager& ComputeContext::BufferManagerAccessor::Get(const ComputeContext& context) { return context.ep_.BufferManager(); } -ComputeContext::ComputeContext(WebGpuContext& webgpu_context, - const WebGpuExecutionProvider& ep, - const OpKernel& op_kernel, - OpKernelContext& kernel_context) - : ComputeContextBase(webgpu_context, ep, op_kernel), - kernel_context_{kernel_context} { +const SplitKConfig& ComputeContext::GetSplitKConfig() { + return webgpu_context_.GetSplitKConfig(); } } // namespace webgpu diff --git a/onnxruntime/core/providers/webgpu/compute_context.h b/onnxruntime/core/providers/webgpu/compute_context.h index fdf89854469d6..ed16f2f0a1345 100644 --- a/onnxruntime/core/providers/webgpu/compute_context.h +++ b/onnxruntime/core/providers/webgpu/compute_context.h @@ -24,13 +24,7 @@ namespace webgpu { class WebGpuContext; class BufferManager; -// -// Class ComputeContextBase is designed to provide basic context information -// for running a compute shader program. -// -// An instance of ComputeContextBase does not depend on OpKernelContext, which needs an execution frame to be created. -// -class ComputeContextBase { +class ComputeContext final { public: // Nested accessor class to provide controlled access to BufferManager class BufferManagerAccessor { @@ -40,31 +34,18 @@ class ComputeContextBase { friend class WebGpuContext; private: - static const webgpu::BufferManager& Get(const ComputeContextBase& context); + static const webgpu::BufferManager& Get(const ComputeContext& context); }; - ComputeContextBase(WebGpuContext& webgpu_context, - const WebGpuExecutionProvider& ep, - const OpKernel& op_kernel); - - ~ComputeContextBase() = default; - - // - // Get the node name. 
- // - inline decltype(auto) NodeName() const { - return op_kernel_.Node().Name(); - } + ComputeContext(OpKernelContext& kernel_context, + const OpKernel& op_kernel, + const WebGpuExecutionProvider& ep, + WebGpuContext& webgpu_context); - // - // Get the operator type. - // - inline decltype(auto) OpType() const { - return op_kernel_.Node().OpType(); - } + ~ComputeContext() = default; // - // Get various information from the WebGPU context. + // Get various information from the context. // inline const wgpu::AdapterInfo& AdapterInfo() const { @@ -76,6 +57,9 @@ class ComputeContextBase { inline bool HasFeature(wgpu::FeatureName feature) const { return webgpu_context_.DeviceHasFeature(feature); } + inline bool IsGraphCaptureEnabled() const { + return ep_.IsGraphCaptureEnabled(); + } #if !defined(__wasm__) inline const wgpu::AdapterPropertiesSubgroupMatrixConfigs& SubgroupMatrixConfigs() const { return webgpu_context_.SubgroupMatrixConfigs(); @@ -83,57 +67,17 @@ class ComputeContextBase { #endif // - // Get Split-K configuration. - // - inline const SplitKConfig& GetSplitKConfig() const { - return webgpu_context_.GetSplitKConfig(); - } - - // - // Get whether graph capture is enabled. + // Get the kernel context. // - inline bool IsGraphCaptureEnabled() const { - return ep_.IsGraphCaptureEnabled(); + inline OpKernelContext& KernelContext() { + return kernel_context_; } // // Get the logger. // inline const logging::Logger& Logger() const { - return *ep_.GetLogger(); - } - - // - // Run a compute shader program. - // - inline Status RunProgram(const ProgramBase& program) { - return webgpu_context_.Run(*this, program); - } - - protected: - WebGpuContext& webgpu_context_; - const WebGpuExecutionProvider& ep_; - const OpKernel& op_kernel_; -}; - -// -// Class ComputeContext provides all information a `ComputeContextBase` provides, and also -// access to `OpKernelContext` for input and output tensors. -// -class ComputeContext final : public ComputeContextBase { - public: - ComputeContext(WebGpuContext& webgpu_context, - const WebGpuExecutionProvider& ep, - const OpKernel& op_kernel, - OpKernelContext& kernel_context); - - ~ComputeContext() = default; - - // - // Get the kernel context. - // - inline OpKernelContext& KernelContext() { - return kernel_context_; + return kernel_context_.Logger(); } // @@ -201,8 +145,25 @@ class ComputeContext final : public ComputeContextBase { return op_kernel_.Info().GetDataTransferManager().CopyTensor(src, dst); } + // + // Run a compute shader program. + // + inline Status RunProgram(const ProgramBase& program) { + return webgpu_context_.Run(*this, program); + } + + // + // Get Split-K configuration. + // + // `split_k_config_` won't be initialized until the first call to this method. 
+ // + const SplitKConfig& GetSplitKConfig(); + private: + WebGpuContext& webgpu_context_; OpKernelContext& kernel_context_; + const OpKernel& op_kernel_; + const WebGpuExecutionProvider& ep_; }; } // namespace webgpu diff --git a/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc b/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc index 3c974ef5133c0..82645e30082e6 100644 --- a/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc +++ b/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc @@ -322,14 +322,11 @@ std::string GetPowImpl(int lhs_element_type, int /* rhs_element_type */) { if (lhs_element_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32) { round_str = "round"; } - std::string use_pow_shortcut; + std::string use_sqrt_for_pow; if (lhs_element_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT || lhs_element_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16) { - // use multiplication instead of pow when base (a) is a float and exponent (b) is 2.0 // use sqrt instead of pow when base (a) is a positive float and exponent (b) is 0.5 - use_pow_shortcut = - " else if (b == 2.0) {\n" - " return a * a;\n" - " } else if (a >= input_a_element_t(0.0) && b == 0.5) {\n" + use_sqrt_for_pow = + " else if (a >= input_a_element_t(0.0) && b == 0.5) {\n" " return sqrt(a);\n" " }\n"; } @@ -340,7 +337,7 @@ std::string GetPowImpl(int lhs_element_type, int /* rhs_element_type */) { " } else if (a < input_a_element_t(0.0) && b != floor(b)) {\n" " return input_a_element_t(pow(f32(a), b)); // NaN\n" " }\n" - << use_pow_shortcut + << use_sqrt_for_pow << " return select(sign(a), input_a_element_t(1.0), round(abs(b) % 2.0) != 1.0) * input_a_element_t(" << round_str << "(pow(f32(abs(a)), b)));\n" << "}\n" "fn pow_v(a : vec4, b : vec4) -> vec4 {\n" diff --git a/onnxruntime/core/providers/webgpu/math/gemm_packed.cc b/onnxruntime/core/providers/webgpu/math/gemm_packed.cc index c26b58a7af1f4..6aefa90a59285 100644 --- a/onnxruntime/core/providers/webgpu/math/gemm_packed.cc +++ b/onnxruntime/core/providers/webgpu/math/gemm_packed.cc @@ -93,21 +93,18 @@ Status ApplyGemmPacked(const Tensor* a, } const uint32_t TILE_SIZE = 32; - const uint32_t dispatch_x = (N + TILE_SIZE - 1) / TILE_SIZE; - const uint32_t dispatch_y = (M + TILE_SIZE - 1) / TILE_SIZE; + const uint32_t num_tile_n = (N + TILE_SIZE - 1) / TILE_SIZE; + const uint32_t num_tile_m = (M + TILE_SIZE - 1) / TILE_SIZE; program.CacheHint(alpha, transA, transB, c_is_scalar) .AddOutputs({{y, ProgramTensorMetadataDependency::TypeAndRank, output_components}}) - .SetDispatchGroupSize(dispatch_x, dispatch_y, 1u) + .SetDispatchGroupSize(num_tile_n, num_tile_m, 1) .SetWorkgroupSize(GemmProgram::MATMUL_PACKED_WORKGROUP_SIZE_X, GemmProgram::MATMUL_PACKED_WORKGROUP_SIZE_Y, GemmProgram::MATMUL_PACKED_WORKGROUP_SIZE_Z) .AddUniformVariables({{alpha}, {beta}, - {M}, /* dim_a_outer */ - {N}, /* dim_b_outer */ - {K}, /*dim_inner */ - {dispatch_x}, /* logical_dispatch_x */ - {dispatch_y}, /* logical_dispatch_y */ - {1u}} /* logical_dispatch_z */ + {M}, /* dim_a_outer */ + {N}, /* dim_b_outer */ + {K}} /*dim_inner */ ); return context.RunProgram(program); diff --git a/onnxruntime/core/providers/webgpu/math/gemm_packed.h b/onnxruntime/core/providers/webgpu/math/gemm_packed.h index cb89ccefba313..dce5164693aa8 100644 --- a/onnxruntime/core/providers/webgpu/math/gemm_packed.h +++ b/onnxruntime/core/providers/webgpu/math/gemm_packed.h @@ -32,10 +32,7 @@ class GemmProgram final : public Program { {"beta", ProgramUniformVariableDataType::Float32}, 
{"dim_a_outer", ProgramUniformVariableDataType::Uint32}, {"dim_b_outer", ProgramUniformVariableDataType::Uint32}, - {"dim_inner", ProgramUniformVariableDataType::Uint32}, - {"logical_dispatch_x", ProgramUniformVariableDataType::Uint32}, - {"logical_dispatch_y", ProgramUniformVariableDataType::Uint32}, - {"logical_dispatch_z", ProgramUniformVariableDataType::Uint32}); + {"dim_inner", ProgramUniformVariableDataType::Uint32}); constexpr static uint32_t MATMUL_PACKED_WORKGROUP_SIZE_X = 8; constexpr static uint32_t MATMUL_PACKED_WORKGROUP_SIZE_Y = 8; diff --git a/onnxruntime/core/providers/webgpu/math/gemm_utils.cc b/onnxruntime/core/providers/webgpu/math/gemm_utils.cc index 89718149cea88..7cbc7f6a4a821 100644 --- a/onnxruntime/core/providers/webgpu/math/gemm_utils.cc +++ b/onnxruntime/core/providers/webgpu/math/gemm_utils.cc @@ -117,20 +117,6 @@ void HandleMatMulWithSplitK( } } -// Compute `logical_workgroup_id` and `logical_global_id` because the dispatch workgroup size in -// `ProgramBase.SetDispatchGroupSize()` may be normalized in -// `ProgramManager::NormalizeDispatchGroupSize()`. In the shader we should always use -// `logical_workgroup_id` and `logical_global_id` instead of `workgroup_id` and `global_id`. -void InitializeLogicalWorkgroupIDAndGlobalID(ShaderHelper& shader) { - shader.MainFunctionBody() - << " let logical_workgroup_id_z = workgroup_idx / (uniforms.logical_dispatch_x * uniforms.logical_dispatch_y);\n" - << " let logical_workgroup_id_y = (workgroup_idx % (uniforms.logical_dispatch_x * uniforms.logical_dispatch_y)) / uniforms.logical_dispatch_x;\n" - << " let logical_workgroup_id_x = (workgroup_idx % (uniforms.logical_dispatch_x * uniforms.logical_dispatch_y)) % uniforms.logical_dispatch_x;\n" - << " let logical_workgroup_id = vec3u(logical_workgroup_id_x, logical_workgroup_id_y, logical_workgroup_id_z);\n" - << " const workgroupSize = vec3u(workgroup_size_x, workgroup_size_y, workgroup_size_z);\n" - << " let logical_global_id = logical_workgroup_id * workgroupSize + local_id;\n"; -} - } // namespace void MatMulReadFnSource(ShaderHelper& shader, @@ -288,22 +274,20 @@ Status MakeMatMulPackedVec4Source(ShaderHelper& shader, << "const innerElementSize = " << inner_elements_size << ";\n" << "const tileInner = " << tile_inner << ";\n"; - InitializeLogicalWorkgroupIDAndGlobalID(shader); - shader.MainFunctionBody() << " let localRow = i32(local_id.y);\n" << " let tileRow = localRow * rowPerThread;\n" << " let tileCol = i32(local_id.x);\n" - << " let globalRow = i32(logical_global_id.y) * rowPerThread;\n" - << " let globalCol = i32(logical_global_id.x);\n" - << " let globalRowStart = i32(logical_workgroup_id.y) * " << tile_a_outer << ";\n" - << " let globalColStart = i32(logical_workgroup_id.x) * " << tile_b_outer << ";\n" + << " let globalRow = i32(global_id.y) * rowPerThread;\n" + << " let globalCol = i32(global_id.x);\n" + << " let globalRowStart = i32(workgroup_id.y) * " << tile_a_outer << ";\n" + << " let globalColStart = i32(workgroup_id.x) * " << tile_b_outer << ";\n" << " var acc: array, rowPerThread>;\n"; if (split_k) { // With Split-K, the original "workgroup" (with dispatch_z == 1 in API side) is split into // multiple ones, and in the current workgroup we only compute `kSplitK` elements starting from - // `kSplitK * i32(logical_global_id.z)`. + // `kSplitK * i32(global_id.z)`. // // For example: considering computing Y = (X * W + B) in one workgroup. 
// Let kSplitK = 2, B = [d1, d2] @@ -321,15 +305,15 @@ Status MakeMatMulPackedVec4Source(ShaderHelper& shader, // Workgroup1: compute (A1 * A2) Workgroup2: compute (B1 * B2) // Workgroup3: compute (C1 * C2) // In each workgroup: - // - `num_tiles` is computed with `kSplitK`, and `kStart` is computed with `logical_global_id.z` + // - `num_tiles` is computed with `kSplitK`, and `kStart` is computed with `global_id.z` // - When the computation in each workgroup is completed, add the result to Y with several // atomic built-in functions in `HandleMatMulWithSplitK()`. shader.MainFunctionBody() << "const kSplitK = " << split_dim_inner << ";\n" << " let num_tiles = (kSplitK - 1) / tileInner + 1;\n" - << " var kStart = kSplitK * i32(logical_global_id.z);\n" + << " var kStart = kSplitK * i32(global_id.z);\n" - // When Split-K is used, `batch` should always be 0 and `logical_global_id.z` is used to indicate + // When Split-K is used, `batch` should always be 0 and `global_id.z` is used to indicate // the index of split-k instead of batch. << " let batch = 0;\n" << " let batchIndices = 0u;\n"; @@ -337,7 +321,7 @@ Status MakeMatMulPackedVec4Source(ShaderHelper& shader, shader.MainFunctionBody() << " let num_tiles = (uniforms.dim_inner - 1) / tileInner + 1;\n" << " var kStart = 0;\n" - << " let batch = i32(logical_global_id.z);\n" + << " let batch = i32(global_id.z);\n" << (nullptr != batch_dims ? " let batchIndices = " + batch_dims->OffsetToIndices("u32(batch)") + ";\n" : ""); } @@ -514,9 +498,7 @@ Status MakeMatMulPackedSource(ShaderHelper& shader, << "const colPerThread = " << elements_per_thread_x << ";\n" << "const tileInner = " << tile_inner << ";\n"; - InitializeLogicalWorkgroupIDAndGlobalID(shader); - - shader.MainFunctionBody() << " let batch = i32(logical_global_id.z);\n" + shader.MainFunctionBody() << " let batch = i32(global_id.z);\n" << (nullptr != batch_dims ? " let batchIndices = " + batch_dims->OffsetToIndices("u32(batch)") + ";\n" : "") << " let num_tiles = (uniforms.dim_inner - 1) / tileInner + 1;\n" << " var kStart = 0;\n" @@ -525,10 +507,10 @@ Status MakeMatMulPackedSource(ShaderHelper& shader, shader.MainFunctionBody() << "let tileRow = i32(local_id.y) * rowPerThread;\n" << "let tileCol = i32(local_id.x) * colPerThread;\n" - << "let globalRow = i32(logical_global_id.y) * rowPerThread;\n" - << "let globalCol = i32(logical_global_id.x) * colPerThread;\n" - << "let globalRowStart = i32(logical_workgroup_id.y) * " << tile_a_outer << ";\n" - << "let globalColStart = i32(logical_workgroup_id.x) * " << tile_b_outer << ";\n" + << "let globalRow = i32(global_id.y) * rowPerThread;\n" + << "let globalCol = i32(global_id.x) * colPerThread;\n" + << "let globalRowStart = i32(workgroup_id.y) * " << tile_a_outer << ";\n" + << "let globalColStart = i32(workgroup_id.x) * " << tile_b_outer << ";\n" << "let tileRowA = i32(local_id.y) * " << row_per_thread_a << ";\n" << "let tileColA = i32(local_id.x) * " << col_per_thread_a << ";\n" << "let tileRowB = i32(local_id.y) * " << row_per_thread_b << ";\n"; diff --git a/onnxruntime/core/providers/webgpu/math/matmul.cc b/onnxruntime/core/providers/webgpu/math/matmul.cc index 72dd235eb820a..55c2c5773cc1f 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul.cc +++ b/onnxruntime/core/providers/webgpu/math/matmul.cc @@ -256,6 +256,8 @@ Status ComputeMatMul(ComputeContext* context, // With Split-K, `dim_inner` will be split into multiple parts and `dispatch_z` will be the // number of splits along `dim_inner`. 
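As a CPU-side illustration of the Split-K scheme these comments describe (not the provider's code): the inner dimension is cut into ceil(dim_inner / kSplitK) slices, each slice produces a partial result, and the partials are accumulated into the final output, which the WGSL code does with atomic adds; the ceiling-division dispatch computation appears just below.

```cpp
#include <cstdint>
#include <vector>

// Number of Split-K slices: ceiling division of the inner dimension by the
// slice size, matching `(dim_inner + split_dim_inner - 1) / split_dim_inner`.
// Assumes split_dim_inner > 0.
uint32_t NumSplits(uint32_t dim_inner, uint32_t split_dim_inner) {
  return (dim_inner + split_dim_inner - 1) / split_dim_inner;
}

// CPU analogue of the accumulation: slice z reduces elements
// [z * split, min((z + 1) * split, n)) to a partial sum, and the partial sums
// are added into the final value, which the shader does with atomic adds.
// Assumes a.size() == b.size() and split > 0.
float SplitKDot(const std::vector<float>& a, const std::vector<float>& b, uint32_t split) {
  const uint32_t n = static_cast<uint32_t>(a.size());
  float y = 0.0f;
  for (uint32_t z = 0; z < NumSplits(n, split); ++z) {
    float partial = 0.0f;
    for (uint32_t k = z * split; k < n && k < (z + 1) * split; ++k) {
      partial += a[k] * b[k];
    }
    y += partial;
  }
  return y;
}
```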
+ // TODO: avoid using `global_id.xxx` or `workgroup_id.xxx` in `MatMulProgram` when we normalize + // the dispatch size with `ProgramManager::NormalizeDispatchGroupSize()` for `MatMulProgram`. split_dim_inner = split_k_config.GetSplitDimInner(); dispatch_z = (dim_inner + split_dim_inner - 1) / split_dim_inner; @@ -269,7 +271,7 @@ Status ComputeMatMul(ComputeContext* context, .CacheHint(activation.ToString(), absl::StrJoin(elements_per_thread, "-"), std::to_string(is_vec4), components, is_channels_last, split_dim_inner) .AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, a_shape_temp, components}, {b, ProgramTensorMetadataDependency::TypeAndRank, b_shape_temp, components}}) - .AddUniformVariables({{dim_a_outer}, {dim_b_outer}, {dim_inner}, {dispatch_x}, {dispatch_y}, {dispatch_z}}) + .AddUniformVariables({{dim_a_outer}, {dim_b_outer}, {dim_inner}}) .AddIndices(outer_dims) .SetDispatchGroupSize(dispatch_x, dispatch_y, dispatch_z) .SetWorkgroupSize(MatMul::MATMUL_PACKED_WORKGROUP_SIZE_X, MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Y, MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Z) diff --git a/onnxruntime/core/providers/webgpu/math/matmul_packed.h b/onnxruntime/core/providers/webgpu/math/matmul_packed.h index dbd193bc38f58..143ba61c99e13 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul_packed.h +++ b/onnxruntime/core/providers/webgpu/math/matmul_packed.h @@ -24,10 +24,7 @@ class MatMulProgram final : public Program { Status GenerateShaderCode(ShaderHelper& sh) const override; WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"dim_a_outer", ProgramUniformVariableDataType::Uint32}, {"dim_b_outer", ProgramUniformVariableDataType::Uint32}, - {"dim_inner", ProgramUniformVariableDataType::Uint32}, - {"logical_dispatch_x", ProgramUniformVariableDataType::Uint32}, - {"logical_dispatch_y", ProgramUniformVariableDataType::Uint32}, - {"logical_dispatch_z", ProgramUniformVariableDataType::Uint32}); + {"dim_inner", ProgramUniformVariableDataType::Uint32}); bool NeedSplitK() const; diff --git a/onnxruntime/core/providers/webgpu/math/softmax.cc b/onnxruntime/core/providers/webgpu/math/softmax.cc index bf3bb53341418..2f34aa21c8309 100644 --- a/onnxruntime/core/providers/webgpu/math/softmax.cc +++ b/onnxruntime/core/providers/webgpu/math/softmax.cc @@ -64,7 +64,7 @@ Status SoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const { int components = input.NumComponents(); const std::string thread_max_decl = is_fp32_ - ? "var thread_max = x_value_t(-3.4028234663852886e+38f);\n" + ? "var thread_max = x_value_t(-3.402823e+38f);\n" : "var thread_max = x_value_t(-65504.0h);\n"; // Define shared memory for row max and row sum diff --git a/onnxruntime/core/providers/webgpu/nn/conv.cc b/onnxruntime/core/providers/webgpu/nn/conv.cc index 4fff736fd2f32..77fa46cb87518 100644 --- a/onnxruntime/core/providers/webgpu/nn/conv.cc +++ b/onnxruntime/core/providers/webgpu/nn/conv.cc @@ -216,46 +216,6 @@ Status Conv::ComputeInternal(ComputeContext& context return context.RunProgram(conv2d_mm_program); } -template -Status Conv::PrePackInternal(ComputeContextBase& /* context */, - const Tensor& tensor, - int input_idx, - AllocatorPtr /* alloc */, - /*out*/ bool& is_packed) { - is_packed = false; - - if constexpr (is_channels_last) { - if (input_idx == 1 && tensor.Shape().NumDimensions() == 4) { - // only deal with 4D NHWC weights - - // TODO: implement weight transpose for pre-pack here - // Conv::ComputeInternal() should be updated to reflect the change: - // - if the initializer is packed, `context.Input(1)` will be nullptr. 
- // - in this case, use `transposed_kernel_` instead. - - // // Step.1 - calculate transposed weight shape - // TensorShape transposed_kernel_shape{tensor.Shape()[2], - // tensor.Shape()[3], - // tensor.Shape()[1], - // tensor.Shape()[0]}; - - // // Step.2 - create transposed weight tensor - // transposed_kernel_ = std::make_unique(tensor.DataType(), transposed_kernel_shape, alloc); - - // // Step.3 - do transpose - // size_t perm[] = {2, 3, 1, 0}; - // ORT_RETURN_IF_ERROR(Transpose::DoTranspose(context, - // perm, - // tensor, - // *transposed_kernel_)); - - // is_packed = true; // set this flag to true so that ORT will release the initializer tensor - } - } - - return Status::OK(); -} - // Explicit template instantiation for FusedConv template class Conv; template class Conv; diff --git a/onnxruntime/core/providers/webgpu/nn/conv.h b/onnxruntime/core/providers/webgpu/nn/conv.h index 5bf94a459a44a..cafaa272c0613 100644 --- a/onnxruntime/core/providers/webgpu/nn/conv.h +++ b/onnxruntime/core/providers/webgpu/nn/conv.h @@ -23,16 +23,9 @@ class Conv : public WebGpuKernel { } Status ComputeInternal(ComputeContext& context) const override; - Status PrePackInternal(ComputeContextBase& context, - const Tensor& tensor, - int input_idx, - AllocatorPtr alloc, - /*out*/ bool& is_packed) override; - protected: ConvAttributes conv_attrs_; Activation activation_; - std::unique_ptr transposed_kernel_; // should only have value when `is_initializer` AND `is_4D` AND `is_NHWC` }; Status TransposeKernel(ComputeContext& context, const Tensor* kernel, const TensorShape& kernel_shape, Tensor* transposed_kernel, const InlinedVector& perm); diff --git a/onnxruntime/core/providers/webgpu/nn/conv2d_mm.cc b/onnxruntime/core/providers/webgpu/nn/conv2d_mm.cc index c66f2cbd582d9..2d5424c52a3f2 100644 --- a/onnxruntime/core/providers/webgpu/nn/conv2d_mm.cc +++ b/onnxruntime/core/providers/webgpu/nn/conv2d_mm.cc @@ -226,10 +226,7 @@ Conv2dMMProgram CreateConv2dMMProgram(const Activation& activation, const std::v {static_cast(dim_inner)}, {pads}, {strides}, - {dilations}, - {dispatch[0]}, - {dispatch[1]}, - {dispatch[2]}}); + {dilations}}); return program; } diff --git a/onnxruntime/core/providers/webgpu/nn/conv2d_mm.h b/onnxruntime/core/providers/webgpu/nn/conv2d_mm.h index e161bffb0c503..d7cc08aae26f3 100644 --- a/onnxruntime/core/providers/webgpu/nn/conv2d_mm.h +++ b/onnxruntime/core/providers/webgpu/nn/conv2d_mm.h @@ -38,10 +38,7 @@ class Conv2dMMProgram final : public Program { {"dim_inner", ProgramUniformVariableDataType::Uint32}, {"pads", ProgramUniformVariableDataType::Uint32}, {"strides", ProgramUniformVariableDataType::Uint32}, - {"dilations", ProgramUniformVariableDataType::Uint32}, - {"logical_dispatch_x", ProgramUniformVariableDataType::Uint32}, - {"logical_dispatch_y", ProgramUniformVariableDataType::Uint32}, - {"logical_dispatch_z", ProgramUniformVariableDataType::Uint32}); + {"dilations", ProgramUniformVariableDataType::Uint32}); private: const Activation& activation_; diff --git a/onnxruntime/core/providers/webgpu/tensor/slice.cc b/onnxruntime/core/providers/webgpu/tensor/slice.cc index 5f59fecc425e2..7e8b434431781 100644 --- a/onnxruntime/core/providers/webgpu/tensor/slice.cc +++ b/onnxruntime/core/providers/webgpu/tensor/slice.cc @@ -92,28 +92,14 @@ Status SliceProgram::GenerateShaderCode(ShaderHelper& shader) const { return Status::OK(); } -static std::vector getInt64Input(const Tensor* tensor) { - if (tensor->IsDataType()) { - return std::vector(tensor->DataAsSpan().begin(), 
tensor->DataAsSpan().end()); - } - ORT_ENFORCE(tensor->IsDataType(), "Expected tensor of type int32 or int64"); - std::vector result; - auto span = tensor->DataAsSpan(); - result.reserve(span.size()); - for (auto v : span) { - result.push_back(static_cast(v)); - } - return result; -} - Status Slice::ComputeInternal(ComputeContext& context) const { // READ INPUTS const Tensor* input_tensor = context.Input(0); const TensorShape& input_shape = input_tensor->Shape(); auto input_rank = input_shape.NumDimensions(); - auto starts_raw = attr_starts_.empty() ? getInt64Input(context.Input(1)) : attr_starts_; - auto ends_raw = attr_ends_.empty() ? getInt64Input(context.Input(2)) : attr_ends_; + auto starts_raw = attr_starts_.empty() ? context.Input(1)->DataAsSpan() : gsl::make_span(attr_starts_); + auto ends_raw = attr_ends_.empty() ? context.Input(2)->DataAsSpan() : gsl::make_span(attr_ends_); ORT_ENFORCE(starts_raw.size() == ends_raw.size(), "starts and ends must have the same size"); @@ -140,7 +126,7 @@ Status Slice::ComputeInternal(ComputeContext& context) const { axes_default.push_back(i); } } - auto axes_raw = attr_axes_.empty() ? (axes_tensor == nullptr ? axes_default : getInt64Input(axes_tensor)) : attr_axes_; + auto axes_raw = attr_axes_.empty() ? (axes_tensor == nullptr ? gsl::make_span(axes_default) : axes_tensor->DataAsSpan()) : gsl::make_span(attr_axes_); std::vector steps_default; if (steps_tensor == nullptr) { @@ -149,7 +135,7 @@ Status Slice::ComputeInternal(ComputeContext& context) const { steps_default.push_back(1); } } - auto steps_raw = steps_tensor == nullptr ? steps_default : getInt64Input(steps_tensor); + auto steps_raw = steps_tensor == nullptr ? gsl::make_span(steps_default) : steps_tensor->DataAsSpan(); // get final axes std::vector axes, axes_fixed; diff --git a/onnxruntime/core/providers/webgpu/tensor/transpose.cc b/onnxruntime/core/providers/webgpu/tensor/transpose.cc index 5415d4a5ead5b..cec321d0da80e 100644 --- a/onnxruntime/core/providers/webgpu/tensor/transpose.cc +++ b/onnxruntime/core/providers/webgpu/tensor/transpose.cc @@ -108,7 +108,7 @@ Status TransposeProgram::GenerateShaderCode(ShaderHelper& shader) const { return Status::OK(); } -Status Transpose::DoTranspose(onnxruntime::webgpu::ComputeContextBase& context, +Status Transpose::DoTranspose(onnxruntime::webgpu::ComputeContext& context, gsl::span permutations, const Tensor& input, Tensor& output) { const auto& input_shape = input.Shape(); diff --git a/onnxruntime/core/providers/webgpu/tensor/transpose.h b/onnxruntime/core/providers/webgpu/tensor/transpose.h index 5e9ccc6750cd6..b62a419fa12bc 100644 --- a/onnxruntime/core/providers/webgpu/tensor/transpose.h +++ b/onnxruntime/core/providers/webgpu/tensor/transpose.h @@ -16,7 +16,7 @@ class Transpose final : public WebGpuKernel, public TransposeBase { Transpose(const OpKernelInfo& info) : WebGpuKernel{info}, TransposeBase{info} { } Status ComputeInternal(ComputeContext& context) const override; - static Status DoTranspose(onnxruntime::webgpu::ComputeContextBase& context, gsl::span permutations, const Tensor& input, Tensor& output); + static Status DoTranspose(onnxruntime::webgpu::ComputeContext& context, gsl::span permutations, const Tensor& input, Tensor& output); constexpr static uint32_t TILE_SIZE = 16; }; diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc index b8d5adc421124..28decb076951e 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.cc +++ 
b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -147,9 +147,6 @@ void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_confi // create program manager program_mgr_ = std::make_unique(*this); - // create split-k config - split_k_config_ = std::make_unique(adapter_info_); - // set query type #if !defined(__wasm__) if (DeviceHasFeature(wgpu::FeatureName::ChromiumExperimentalTimestampQueryInsidePasses)) { @@ -181,7 +178,7 @@ Status WebGpuContext::Wait(wgpu::Future f) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to wait for the operation:", uint32_t(status)); } -Status WebGpuContext::Run(ComputeContextBase& context, const ProgramBase& program) { +Status WebGpuContext::Run(ComputeContext& context, const ProgramBase& program) { const auto& inputs = program.Inputs(); const auto& outputs = program.Outputs(); @@ -291,8 +288,8 @@ Status WebGpuContext::Run(ComputeContextBase& context, const ProgramBase& progra auto key = CalculateProgramCacheKey(program, inputs_segments, outputs_segments, is_1d_dispatch); if (is_profiling_) { - PendingKernelInfo pending_kernel_info(context.NodeName(), - context.OpType(), + PendingKernelInfo pending_kernel_info(context.KernelContext().GetNodeName(), + context.KernelContext().GetOpType(), program.Name(), key, inputs, @@ -445,7 +442,7 @@ Status WebGpuContext::Run(ComputeContextBase& context, const ProgramBase& progra const size_t uniform_buffer_total_size = (current_offset + max_alignment_of_field - 1) / max_alignment_of_field * max_alignment_of_field; WGPUBuffer uniform_buffer = nullptr; - const webgpu::BufferManager& buffer_mgr = ComputeContextBase::BufferManagerAccessor::Get(context); + const webgpu::BufferManager& buffer_mgr = ComputeContext::BufferManagerAccessor::Get(context); if (uniform_buffer_total_size > 0) { std::vector uniform_data_buffer(uniform_buffer_total_size); @@ -913,6 +910,13 @@ void WebGpuContext::ReleaseGraphResources(std::vector WebGpuContextFactory::contexts_; std::mutex WebGpuContextFactory::mutex_; std::once_flag WebGpuContextFactory::init_default_flag_; diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.h b/onnxruntime/core/providers/webgpu/webgpu_context.h index 84dfb47ef4687..bd7dae75f2e2d 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.h +++ b/onnxruntime/core/providers/webgpu/webgpu_context.h @@ -5,6 +5,7 @@ #include #include +#include #include "core/providers/webgpu/webgpu_external_header.h" @@ -22,7 +23,7 @@ class Tensor; namespace webgpu { class WebGpuContext; -class ComputeContextBase; +class ComputeContext; class ProgramBase; // Definition for CapturedCommandInfo in the webgpu namespace @@ -151,13 +152,6 @@ class WebGpuContext final { return validation_mode_; } - // - // Get Split-K configuration. - // - const SplitKConfig& GetSplitKConfig() const { - return *split_k_config_; - } - void StartProfiling(); void CollectProfilingData(profiling::Events& events); void EndProfiling(TimePoint, profiling::Events& events, profiling::Events& cached_events); @@ -176,9 +170,16 @@ class WebGpuContext final { // Status PopErrorScope(); - Status Run(ComputeContextBase& context, const ProgramBase& program); + Status Run(ComputeContext& context, const ProgramBase& program); void OnRunEnd(); + // + // Get Split-K configuration. + // + // `split_k_config_` won't be initialized until the first call to this method. 
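With this change the context no longer builds its Split-K configuration during initialization; instead `GetSplitKConfig()` creates it on first use and caches it in an optional member (see the header change below). A generic sketch of that lazy-initialization pattern, using placeholder names rather than the actual classes:

```cpp
#include <optional>

struct Config {
  int split_dim_inner = 0;
};

// Build the config lazily on first use instead of in the constructor; later
// calls return the cached value. `Holder`, `Config`, and `MakeConfig` are
// illustrative names for this sketch only.
class Holder {
 public:
  const Config& Get() {
    if (!config_.has_value()) {
      config_ = MakeConfig();  // device- or adapter-dependent setup happens here
    }
    return *config_;
  }

 private:
  static Config MakeConfig() { return Config{256}; }
  std::optional<Config> config_;
};
```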
+ // + const SplitKConfig& GetSplitKConfig(); + private: enum class TimestampQueryType { None = 0, @@ -276,7 +277,7 @@ class WebGpuContext final { uint32_t num_pending_dispatches_ = 0; const uint32_t max_num_pending_dispatches_ = 16; - std::unique_ptr split_k_config_; + std::optional split_k_config_; // profiling TimestampQueryType query_type_; diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index 6b764d51bcf75..e0b84fef51f1f 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -794,7 +794,8 @@ using namespace webgpu; WebGpuExecutionProvider::WebGpuExecutionProvider(int context_id, WebGpuContext& context, WebGpuExecutionProviderConfig&& config) - : IExecutionProvider{kWebGpuExecutionProvider, WebGpuDevice}, + : IExecutionProvider{kWebGpuExecutionProvider, + OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NONE, 0)}, context_id_{context_id}, context_{context}, preferred_data_layout_{config.data_layout}, @@ -934,14 +935,13 @@ std::unique_ptr WebGpuExecutionProvider::GetEx std::optional WebGpuExecutionProvider::ShouldConvertDataLayoutForOp(std::string_view node_domain, std::string_view node_op_type, DataLayout target_data_layout) const { - // NHWC for Resize operator is not implemented on kWebGpuExecutionProvider - if (node_domain == kOnnxDomain && node_op_type == "Resize") { - return target_data_layout != DataLayout::NHWC; + if (target_data_layout != DataLayout::NHWC) { + return std::nullopt; } - // WebGPU perfer NCHW for InstanceNormalization due to a better performance - if (node_domain == kOnnxDomain && node_op_type == "InstanceNormalization") { - return target_data_layout != DataLayout::NHWC; + // NHWC for Resize operator is not implemented on kWebGpuExecutionProvider + if (node_domain == kOnnxDomain && node_op_type == "Resize") { + return false; } return std::nullopt; diff --git a/onnxruntime/core/providers/webgpu/webgpu_kernel.cc b/onnxruntime/core/providers/webgpu/webgpu_kernel.cc index ea38e9415e1fe..8d6ae6caeaf83 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_kernel.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_kernel.cc @@ -11,58 +11,25 @@ namespace webgpu { WebGpuKernel::WebGpuKernel(const OpKernelInfo& info) : OpKernel(info), - ep_(*static_cast(info.GetExecutionProvider())), - webgpu_context_(WebGpuContextFactory::GetContext(ep_.GetDeviceId())) { + ep_(*static_cast(info.GetExecutionProvider())) { } Status WebGpuKernel::Compute(OpKernelContext* p_op_kernel_context) const { - ComputeContext context{webgpu_context_, - ep_, - *this, - *p_op_kernel_context}; + WebGpuContext& webgpu_context = WebGpuContextFactory::GetContext(ep_.GetDeviceId()); + ComputeContext context{*p_op_kernel_context, *this, ep_, webgpu_context}; - if (webgpu_context_.ValidationMode() >= ValidationMode::Full) { - webgpu_context_.PushErrorScope(); + if (webgpu_context.ValidationMode() >= ValidationMode::Full) { + webgpu_context.PushErrorScope(); } Status s = ComputeInternal(context); - if (webgpu_context_.ValidationMode() >= ValidationMode::Full) { - ORT_RETURN_IF_ERROR(webgpu_context_.PopErrorScope()); + if (webgpu_context.ValidationMode() >= ValidationMode::Full) { + ORT_RETURN_IF_ERROR(webgpu_context.PopErrorScope()); } return s; } -Status WebGpuKernel::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - /*out*/ bool& is_packed, /*out*/ PrePackedWeights* /* 
prepacked_weights */) { - ComputeContextBase context{webgpu_context_, ep_, *this}; - - if (webgpu_context_.ValidationMode() >= ValidationMode::Full) { - webgpu_context_.PushErrorScope(); - } - - // Currently, ORT does not allow using prepacked weights in non-CPU EPs. - // So we do not pass prepacked_weights to PrePackInternal. - // Kernel implementation that supports prepacking should manage its own storage. - - Status s = PrePackInternal(context, tensor, input_idx, alloc, is_packed); - - if (webgpu_context_.ValidationMode() >= ValidationMode::Full) { - ORT_RETURN_IF_ERROR(webgpu_context_.PopErrorScope()); - } - - return s; -} - -Status WebGpuKernel::PrePackInternal(ComputeContextBase& /*context*/, - const Tensor& /*tensor*/, - int /*input_idx*/, - AllocatorPtr /*alloc*/, - /*out*/ bool& is_packed) { - is_packed = false; - return Status::OK(); -} - } // namespace webgpu } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_kernel.h b/onnxruntime/core/providers/webgpu/webgpu_kernel.h index 2c57991c6ee35..3c750e305421c 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_kernel.h +++ b/onnxruntime/core/providers/webgpu/webgpu_kernel.h @@ -23,41 +23,8 @@ class WebGpuKernel : public OpKernel { virtual Status ComputeInternal(ComputeContext& context) const = 0; - // Overrides OpKernel::PrePack to handle constant tensor pre-processing for WebGPU kernels. - // This method creates a ComputeContextBase and delegates to PrePackInternal. - // - // NOTE: Currently, ORT does not allow using prepacked weights in non-CPU EPs, so the - // prepacked_weights parameter is not passed to PrePackInternal. Kernel implementations - // that support prepacking should manage their own storage. - Status PrePack(const Tensor& tensor, - int input_idx, - AllocatorPtr alloc, - /*out*/ bool& is_packed, - /*out*/ PrePackedWeights* prepacked_weights) override; - - // Virtual method that allows derived kernels to pre-process constant tensors during initialization. - // - // This method is called during kernel initialization when constant tensors are available, - // allowing kernels to perform operations like tensor transposition or format conversion - // before the first Compute call. - // - // @param context The WebGPU compute context base providing access to the execution environment. - // @param tensor The constant tensor to potentially pre-process. - // @param input_idx The index of this input in the kernel's input list. - // @param alloc The allocator to use for any new tensor allocations. - // @param is_packed Output parameter. Set to true if the tensor was pre-packed/processed, - // false otherwise. The default implementation sets this to false. - // - // @return Status::OK() on success, or an error status on failure. 
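For context, the removed `Conv::PrePackInternal` sketch earlier in this patch planned to transpose 4D OIHW weights with perm = {2, 3, 1, 0}. A minimal helper showing what that permutation does to a shape (`PermuteShape` is an illustrative name, not part of the codebase):

```cpp
#include <array>
#include <cstddef>
#include <cstdint>

// Applying perm = {2, 3, 1, 0} to an {O, I, H, W} shape yields {H, W, I, O},
// the kernel layout the removed channels-last pre-pack sketch targeted.
std::array<int64_t, 4> PermuteShape(const std::array<int64_t, 4>& dims,
                                    const std::array<size_t, 4>& perm) {
  return {dims[perm[0]], dims[perm[1]], dims[perm[2]], dims[perm[3]]};
}
```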
- virtual Status PrePackInternal(ComputeContextBase& context, - const Tensor& tensor, - int input_idx, - AllocatorPtr alloc, - /*out*/ bool& is_packed); - private: const WebGpuExecutionProvider& ep_; - WebGpuContext& webgpu_context_; }; } // namespace webgpu diff --git a/onnxruntime/core/providers/webgpu/webgpu_utils.cc b/onnxruntime/core/providers/webgpu/webgpu_utils.cc index 5fd24b2bff037..568d29a96cb88 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_utils.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_utils.cc @@ -21,24 +21,27 @@ TensorShape ReduceShapeByComponents(const TensorShape& shape, int64_t components return TensorShape(shape_vector); } -SplitKConfig::SplitKConfig(const wgpu::AdapterInfo& adapter_info) { +SplitKConfig SplitKConfig::GetSplitKConfig(const wgpu::AdapterInfo& adapter_info) { + SplitKConfig config = {}; + if (adapter_info.vendor == std::string_view{"intel"}) { if (adapter_info.architecture == std::string_view{"xe-2lpg"} || adapter_info.architecture == std::string_view{"xe-2hpg"} || adapter_info.architecture == std::string_view{"xe-lpg"} || adapter_info.architecture == std::string_view{"gen-12hp"}) { - enable_split_k_ = true; + config.enable_split_k_ = true; // Below thresholds are only verified on the above Intel GPUs without any regressions. The // proper value of `max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner_` may be // reduced when we support a larger `dim_inner` because larger `dim_inner` will bring more // atomic calls for each output value. - split_dim_inner_ = 256; - min_dim_inner_with_split_k_ = split_dim_inner_ * 2; - max_dim_inner_with_split_k_ = split_dim_inner_ * 9; - max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner_ = 35.0f; + config.split_dim_inner_ = 256; + config.min_dim_inner_with_split_k_ = config.split_dim_inner_ * 2; + config.max_dim_inner_with_split_k_ = config.split_dim_inner_ * 9; + config.max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner_ = 35.0f; } } + return config; } bool SplitKConfig::UseSplitK( diff --git a/onnxruntime/core/providers/webgpu/webgpu_utils.h b/onnxruntime/core/providers/webgpu/webgpu_utils.h index 7d5ab5fea8006..d45b9bf4dd119 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_utils.h +++ b/onnxruntime/core/providers/webgpu/webgpu_utils.h @@ -91,12 +91,9 @@ inline Tensor CreateTensorView(const Tensor& tensor, MLDataType new_data_type, c return {new_data_type, new_shape, const_cast(tensor.DataRaw()), tensor.Location()}; } -/** - * Configuration for Split-K optimization (Conv|MatMul). 
- */ class SplitKConfig { public: - explicit SplitKConfig(const wgpu::AdapterInfo& adapter_info); + static SplitKConfig GetSplitKConfig(const wgpu::AdapterInfo& adapter_info); bool UseSplitK( bool is_vec4, ActivationKind activation_kind, uint64_t batch_size, diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index ab3932e7abfb4..4d4dea9cb444c 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -2943,8 +2943,6 @@ Status InferenceSession::Run(const RunOptions& run_options, << cached_execution_provider_for_graph_replay_.Type() << " CUDA Graph for this model with tag: " << run_options.run_tag << " with graph annotation id: " << graph_annotation_id; - // log evaluation start to trace logging provider - env.GetTelemetryProvider().LogEvaluationStart(session_id_); ORT_RETURN_IF_ERROR_SESSIONID_(cached_execution_provider_for_graph_replay_.ReplayGraph(graph_annotation_id)); } else { InlinedVector exec_providers_to_stop; diff --git a/onnxruntime/core/session/utils.cc b/onnxruntime/core/session/utils.cc index 4cb21b80109c8..6189e6ca7f012 100644 --- a/onnxruntime/core/session/utils.cc +++ b/onnxruntime/core/session/utils.cc @@ -404,7 +404,6 @@ Status CompileModel(const Environment& env, const ModelCompilationOptions& model session))); } - Env::Default().GetTelemetryProvider().LogCompileModel(session->GetCurrentSessionId()); ORT_RETURN_IF_ERROR(ToStatusAndRelease(InitializeSession(session_options, *session))); return Status::OK(); } diff --git a/onnxruntime/test/providers/cuda/cuda_mempool_arena_test.cc b/onnxruntime/test/providers/cuda/cuda_mempool_arena_test.cc index 5deef01cd783e..70c7a5b2bcdcb 100644 --- a/onnxruntime/test/providers/cuda/cuda_mempool_arena_test.cc +++ b/onnxruntime/test/providers/cuda/cuda_mempool_arena_test.cc @@ -22,17 +22,10 @@ namespace test { // --------- Helpers --------- -// cuda errors are sticky and may affect subsequent API calls. -// we want to clear the error if when supported check fails. 
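The comment being removed here notes that a CUDA runtime error persists in the per-thread error state, so a failed capability probe can taint later checks unless the error is read back. A small standalone illustration of that behaviour (the invalid device id is only an assumed way to force an error; requires the CUDA runtime):

```cpp
#include <cuda_runtime_api.h>
#include <cstdio>

int main() {
  // An assumed-invalid device id forces a runtime error.
  (void)cudaSetDevice(9999);
  // cudaPeekAtLastError() reads the error state without resetting it...
  std::printf("peek : %d\n", static_cast<int>(cudaPeekAtLastError()));
  // ...while cudaGetLastError() reads it and resets it to cudaSuccess,
  // which is what the removed ClearCudaError() helper relied on.
  std::printf("clear: %d\n", static_cast<int>(cudaGetLastError()));
  std::printf("after: %d\n", static_cast<int>(cudaPeekAtLastError()));  // cudaSuccess (0)
  return 0;
}
```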
-void ClearCudaError() { - ORT_IGNORE_RETURN_VALUE(::cudaGetLastError()); -} - static bool IsCudaMemPoolSupported() { int ort_cuda_rt_version = 0; cudaError_t cuda_status = cudaRuntimeGetVersion(&ort_cuda_rt_version); if (cuda_status != cudaSuccess) { - ClearCudaError(); return false; } @@ -43,7 +36,6 @@ static bool IsCudaMemPoolSupported() { int ort_cuda_driver_version = 0; cuda_status = cudaDriverGetVersion(&ort_cuda_driver_version); if (cuda_status != cudaSuccess) { - ClearCudaError(); return false; } @@ -73,10 +65,9 @@ static bool IsCudaMemPoolSupported() { cudaMemPool_t pool; auto cuda_error = cudaMemPoolCreate(&pool, &props); if (cuda_error != cudaSuccess) { - ClearCudaError(); return false; } - ORT_IGNORE_RETURN_VALUE(cudaMemPoolDestroy(pool)); + cuda_error = cudaMemPoolDestroy(pool); return true; } @@ -89,9 +80,7 @@ static ::cudaStream_t NewCudaStream() { } static void DestroyCudaStream(::cudaStream_t s) { - if (s) { - EXPECT_EQ(cudaSuccess, ::cudaStreamDestroy(s)); - } + if (s) (void)::cudaStreamDestroy(s); } static void TouchDevice(void* p, size_t bytes, ::cudaStream_t s, unsigned char value = 0xAB) { diff --git a/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc index af9706855ee3c..d8cc56d738175 100644 --- a/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc +++ b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc @@ -203,48 +203,6 @@ TEST_P(TypeTests, IOTypes) { } } -TEST(NvExecutionProviderTest, TestSessionOutputs) { - /* - * Model #1: - * - * "input" ---> TopK --- - * |---> "scores" - * |--- Less ---> "Less_output_0" - * |--- Div ---> "Div_output_0" - * |--- Mod ---> "labels" - */ - { - Ort::SessionOptions session_options; - session_options.AppendExecutionProvider(kNvTensorRTRTXExecutionProvider, {}); - - auto model_path = ORT_TSTR("testdata/topk_and_multiple_graph_outputs.onnx"); - Ort::Session session(*ort_env, model_path, session_options); - - size_t output_count = session.GetOutputCount(); - ASSERT_TRUE(output_count == 4); - } - - /* - * Model #2: - * - * "X" ---> Dropout ---> MatMul ---> "Y" - * ^ | - * | | - * "W" ------ ----> Can't be graph's output - * - */ - { - Ort::SessionOptions session_options; - session_options.AppendExecutionProvider(kNvTensorRTRTXExecutionProvider, {}); - - auto model_path = ORT_TSTR("testdata/node_output_not_used.onnx"); - Ort::Session session(*ort_env, model_path, session_options); - - size_t output_count = session.GetOutputCount(); - ASSERT_TRUE(output_count == 1); - } -} - INSTANTIATE_TEST_SUITE_P(NvExecutionProviderTest, TypeTests, ::testing::Values(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, diff --git a/onnxruntime/test/providers/qnn/README.md b/onnxruntime/test/providers/qnn/README.md deleted file mode 100644 index c3d0c720a1aa4..0000000000000 --- a/onnxruntime/test/providers/qnn/README.md +++ /dev/null @@ -1,70 +0,0 @@ -# ONNX Runtime QNN Execution Provider Tests -## Overview -1. The `onnxruntime/test/providers/qnn` directory contains integration tests for the Qualcomm Neural Network (QNN) execution provider. -2. Most testcases run an ONNX model through the QNN-EP, then verifies the inference result against the one on CPU-EP - -## Building the Tests -The tests are built as part of the regular ONNX Runtime build. After a successful build you will have an executable named -- onnxruntime_provider_test.exe (Windows) -- onnxruntime_provider_test (Linux/macOS) - -## Running the Tests -1. QNN supports several backends. 
You can use the standard Google‑Test syntax for filtering: - - `onnxruntime_provider_test.exe --gtest_filter=QnnCPUBackendTests.*` - - `onnxruntime_provider_test.exe --gtest_filter=QnnHTPBackendTests.*` - - `onnxruntime_provider_test.exe --gtest_filter=QnnGPUBackendTests.*` - - `onnxruntime_provider_test.exe --gtest_filter=QnnIRBackendTests.*` -2. Saving Test Artifacts - - For debugging it is often helpful to keep the intermediate files that the tests generate. The following environment - variables are recognized by the test binary: - - `QNN_DUMP_ONNX`: Saves the input ONNX model used for the test - - `QNN_DUMP_JSON`: Save json qnn graph with provider_option `dump_json_qnn_graph` - - `QNN_DUMP_DLC`: Saves the compiled QNN DLC file by specifying the provider_option `backend_path` to `QnnIr.dll` - - The artifacts will be saved to a directory named with `_` - ``` - . - ├── QnnCPUBackendTests_BatchNorm2D_fp32 # RunQnnModelTest - │ ├── dumped_f32_model.onnx # float32 ONNX model - │ ├── QNNExecutionProvider_QNN_XXXX_X_X.dlc - │ └── QNNExecutionProvider_QNN_XXXX_X_X.json - ├── QnnHTPBackendTests_BatchNorm_FP16 # TestFp16ModelAccuracy - │ ├── dumped_f16_model.onnx # float16 ONNX model - │ ├── dumped_f32_model.onnx # float32 ONNX model - │ ├── QNNExecutionProvider_QNN_XXXX_X_X.dlc - │ └── QNNExecutionProvider_QNN_XXXX_X_X.json - └── QnnHTPBackendTests_BatchNorm2D_U8U8S32 # TestQDQModelAccuracy - ├── dumped_f32_model.onnx # float32 ONNX model - ├── dumped_qdq_model.onnx # QDQ ONNX model - ├── QNNExecutionProvider_QNN_XXXX_X_X.dlc - └── QNNExecutionProvider_QNN_XXXX_X_X.json - - # All artifact files are placed under the current working directory from which the test binary is invoked. - ``` -3. Verbose - - `QNN_VERBOSE`: Sets the ONNX Runtime log level to `ORT_LOGGING_LEVEL_VERBOSE` - -4. You can enable any combination of these environment variables, for example: - - On Linux/macOS - ```bash - export QNN_DUMP_ONNX=1 - export QNN_DUMP_JSON=1 - export QNN_DUMP_DLC=1 - export QNN_VERBOSE=1 - ``` - - On Windows - ```cmd - set QNN_DUMP_ONNX=1 - set QNN_DUMP_JSON=1 - set QNN_DUMP_DLC=1 - set QNN_VERBOSE=1 - ``` - ```ps1 - $Env:QNN_DUMP_ONNX = "1" - $Env:QNN_DUMP_JSON = "1" - $Env:QNN_DUMP_DLC = "1" - $Env:QNN_VERBOSE = "1" - ``` - -# Note -- An issue on QNN backends can prevent the test artifacts from being successfully saved. -- The `onnxruntime_provider_test.exe` does not automatically delete the artifact directories, so you may want to prune them after a debugging session. 
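The README deleted above documents the `QNN_DUMP_ONNX` / `QNN_DUMP_JSON` / `QNN_DUMP_DLC` / `QNN_VERBOSE` switches; the `QNNTestEnvironment` helper removed in the files that follow treated a variable as set only when it was present, non-empty, and not "0". A standalone version of that check:

```cpp
#include <cstdlib>

// Mirrors the deleted QNNTestEnvironment::IsEnvVarSet(): a variable counts as
// "set" only when it exists, is non-empty, and is not "0".
bool IsEnvVarSet(const char* name) {
  const char* value = std::getenv(name);
  return value != nullptr && *value != '\0' && *value != '0';
}
```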
diff --git a/onnxruntime/test/providers/qnn/qnn_test_utils.cc b/onnxruntime/test/providers/qnn/qnn_test_utils.cc index 15a9132aaa16c..1c70f4012090e 100644 --- a/onnxruntime/test/providers/qnn/qnn_test_utils.cc +++ b/onnxruntime/test/providers/qnn/qnn_test_utils.cc @@ -101,12 +101,6 @@ void RunQnnModelTest(const GetTestModelFn& build_test_case, ProviderOptions prov int opset_version, ExpectedEPNodeAssignment expected_ep_assignment, float fp32_abs_err, logging::Severity log_severity, bool verify_outputs, std::function* ep_graph_checker) { - std::filesystem::path output_dir; - if (QNNTestEnvironment::GetInstance().dump_onnx() || - QNNTestEnvironment::GetInstance().dump_json() || - QNNTestEnvironment::GetInstance().dump_dlc()) { - output_dir = QNNTestEnvironment::GetInstance().CreateTestcaseDirs(); - } EPVerificationParams verification_params; verification_params.ep_node_assignment = expected_ep_assignment; verification_params.fp32_abs_err = fp32_abs_err; @@ -116,10 +110,6 @@ void RunQnnModelTest(const GetTestModelFn& build_test_case, ProviderOptions prov auto& logging_manager = DefaultLoggingManager(); logging_manager.SetDefaultLoggerSeverity(log_severity); - if (QNNTestEnvironment::GetInstance().verbose()) { - logging_manager.RemoveSink(logging::SinkType::EtwSink); - logging_manager.SetDefaultLoggerSeverity(logging::Severity::kVERBOSE); - } onnxruntime::Model model("QNN_EP_TestModel", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {}, @@ -133,27 +123,7 @@ void RunQnnModelTest(const GetTestModelFn& build_test_case, ProviderOptions prov // Serialize the model to a string. std::string model_data; model.ToProto().SerializeToString(&model_data); - - if (QNNTestEnvironment::GetInstance().dump_onnx()) { - auto dump_path = output_dir / ToPathString("dumped_f32_model.onnx"); - LOGS(logging_manager.DefaultLogger(), VERBOSE) << "Save onnx model at: " << dump_path; - ASSERT_STATUS_OK(onnxruntime::Model::Save(model, dump_path)); - } - TryEnableQNNSaver(provider_options); - if (QNNTestEnvironment::GetInstance().dump_dlc()) { - provider_options["dump_qnn_ir_dlc"] = "1"; - provider_options["dump_qnn_ir_dlc_dir"] = output_dir.string(); -#if defined(_WIN32) - provider_options["qnn_ir_backend_path"] = "QnnIr.dll"; -#else - provider_options["qnn_ir_backend_path"] = "libQnnIr.so"; -#endif // defined(_WIN32) - } - if (QNNTestEnvironment::GetInstance().dump_json()) { - provider_options["dump_json_qnn_graph"] = "1"; - provider_options["json_qnn_graph_dir"] = output_dir.string(); - } RunAndVerifyOutputsWithEP(AsByteSpan(model_data.data(), model_data.size()), "QNN_EP_TestLogID", QnnExecutionProviderWithOptions(provider_options), helper.feeds_, verification_params, @@ -164,21 +134,11 @@ void RunQnnModelTestHTPNoVerify(const GetTestModelFn& build_test_case, ProviderO int opset_version, ExpectedEPNodeAssignment expected_ep_assignment, logging::Severity log_severity, std::function* ep_graph_checker) { - std::filesystem::path output_dir; - if (QNNTestEnvironment::GetInstance().dump_onnx() || - QNNTestEnvironment::GetInstance().dump_dlc() || - QNNTestEnvironment::GetInstance().dump_json()) { - output_dir = QNNTestEnvironment::GetInstance().CreateTestcaseDirs(); - } // Add kMSDomain to cover contrib op like Gelu const std::unordered_map domain_to_version = {{"", opset_version}, {kMSDomain, 1}}; auto& logging_manager = DefaultLoggingManager(); logging_manager.SetDefaultLoggerSeverity(log_severity); - if (QNNTestEnvironment::GetInstance().verbose()) { - 
logging_manager.RemoveSink(logging::SinkType::EtwSink);
-    logging_manager.SetDefaultLoggerSeverity(logging::Severity::kVERBOSE);
-  }

   onnxruntime::Model model("QNN_EP_TestModel", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {},
@@ -192,27 +152,7 @@ void RunQnnModelTestHTPNoVerify(const GetTestModelFn& build_test_case, ProviderO
   // Serialize the model to a string.
   std::string model_data;
   model.ToProto().SerializeToString(&model_data);
-
-  if (QNNTestEnvironment::GetInstance().dump_onnx()) {
-    auto dump_path = output_dir / ToPathString("dumped_f32_model.onnx");
-    LOGS(logging_manager.DefaultLogger(), VERBOSE) << "Save onnx model at: " << dump_path;
-    ASSERT_STATUS_OK(onnxruntime::Model::Save(model, dump_path));
-  }
-
   TryEnableQNNSaver(provider_options);
-  if (QNNTestEnvironment::GetInstance().dump_dlc()) {
-    provider_options["dump_qnn_ir_dlc"] = "1";
-    provider_options["dump_qnn_ir_dlc_dir"] = output_dir.string();
-#if defined(_WIN32)
-    provider_options["qnn_ir_backend_path"] = "QnnIr.dll";
-#else
-    provider_options["qnn_ir_backend_path"] = "libQnnIr.so";
-#endif // defined(_WIN32)
-  }
-  if (QNNTestEnvironment::GetInstance().dump_json()) {
-    provider_options["dump_json_qnn_graph"] = "1";
-    provider_options["json_qnn_graph_dir"] = output_dir.string();
-  }

   SessionOptions so;
   so.session_logid = "QNN_EP_TestLogID";
diff --git a/onnxruntime/test/providers/qnn/qnn_test_utils.h b/onnxruntime/test/providers/qnn/qnn_test_utils.h
index 4d4f795d161b1..aeb3a9a114871 100644
--- a/onnxruntime/test/providers/qnn/qnn_test_utils.h
+++ b/onnxruntime/test/providers/qnn/qnn_test_utils.h
@@ -499,77 +499,6 @@ struct QDQTolerance {
   float value;
 };

-class QNNTestEnvironment {
- public:
-  // Delete copy constructor and assignment operator
-  QNNTestEnvironment(const QNNTestEnvironment&) = delete;
-  QNNTestEnvironment& operator=(const QNNTestEnvironment&) = delete;
-
-  // Static method to get the singleton instance
-  static QNNTestEnvironment& GetInstance() {
-    static QNNTestEnvironment instance;
-    return instance;
-  }
-
-  bool dump_onnx() const { return dump_onnx_; }
-  bool dump_json() const { return dump_json_; }
-  bool dump_dlc() const { return dump_dlc_; }
-  bool verbose() const { return verbose_; }
-
-  std::filesystem::path CreateTestcaseDirs() {
-    std::string test_suite_name = ::testing::UnitTest::GetInstance()->current_test_info()->test_suite_name();
-    std::string test_name = ::testing::UnitTest::GetInstance()->current_test_info()->name();
-    std::filesystem::path output_dir = std::filesystem::current_path() / (test_suite_name + "_" + test_name);
-    std::filesystem::create_directories(output_dir);
-
-    return output_dir;
-  }
-
- private:
-  // Private constructor for singleton
-  QNNTestEnvironment() {
-    ParseEnvironmentVars();
-  }
-
-  // Helper function to check if an environment variable is set
-  bool IsEnvVarSet(const char* name) {
-    const char* value = std::getenv(name);
-    if (value == nullptr) {
-      return false;
-    }
-
-    // Consider the variable set if it's not empty and not "0"
-    return *value != '\0' && *value != '0';
-  }
-
-  void ParseEnvironmentVars() {
-    if (IsEnvVarSet("QNN_DUMP_ONNX")) {
-      std::cout << "[QNN only] ONNX model dumping enabled via environment variable." << std::endl;
-      dump_onnx_ = true;
-    }
-
-    if (IsEnvVarSet("QNN_DUMP_JSON")) {
-      std::cout << "[QNN only] Json QNN Graph dumping enabled via environment variable." << std::endl;
-      dump_json_ = true;
-    }
-
-    if (IsEnvVarSet("QNN_DUMP_DLC")) {
-      std::cout << "[QNN only] DLC dumping enabled via environment variable." << std::endl;
-      dump_dlc_ = true;
-    }
-
-    if (IsEnvVarSet("QNN_VERBOSE")) {
-      std::cout << "Verbose enabled via environment variable." << std::endl;
-      verbose_ = true;
-    }
-  }
-
-  bool dump_onnx_ = false;
-  bool dump_json_ = false;
-  bool dump_dlc_ = false;
-  bool verbose_ = false;
-};
-
 /**
  * Tests the accuracy of a QDQ model on QNN EP by runnning 3 inferences:
  *
@@ -600,21 +529,15 @@ inline void TestQDQModelAccuracy(const GetTestModelFn& f32_model_fn, const GetTe
                                  const std::string& qnn_ctx_model_path = "",
                                  const std::unordered_map& session_option_pairs = {},
                                  std::function* qnn_ep_graph_checker = nullptr) {
-  std::filesystem::path output_dir;
-  if (QNNTestEnvironment::GetInstance().dump_onnx() ||
-      QNNTestEnvironment::GetInstance().dump_dlc() ||
-      QNNTestEnvironment::GetInstance().dump_json()) {
-    output_dir = QNNTestEnvironment::GetInstance().CreateTestcaseDirs();
-  }
   // Add kMSDomain to cover contrib op like Gelu
   const std::unordered_map domain_to_version = {{"", opset_version}, {kMSDomain, 1}};

   auto& logging_manager = DefaultLoggingManager();
+
+  // Uncomment to dump LOGGER() output to stdout.
+  // logging_manager.RemoveSink(logging::SinkType::EtwSink);
+
   logging_manager.SetDefaultLoggerSeverity(log_severity);
-  if (QNNTestEnvironment::GetInstance().verbose()) {
-    logging_manager.RemoveSink(logging::SinkType::EtwSink);
-    logging_manager.SetDefaultLoggerSeverity(logging::Severity::kVERBOSE);
-  }

   // Create float model and serialize it to a string.
   onnxruntime::Model f32_model("f32_model", false, ModelMetaData(), PathString(),
@@ -628,11 +551,8 @@ inline void TestQDQModelAccuracy(const GetTestModelFn& f32_model_fn, const GetTe
   ASSERT_STATUS_OK(f32_model.MainGraph().Resolve());
   f32_model.ToProto().SerializeToString(&f32_model_data);

-  if (QNNTestEnvironment::GetInstance().dump_onnx()) {
-    auto dump_path = output_dir / ToPathString("dumped_f32_model.onnx");
-    LOGS(logging_manager.DefaultLogger(), VERBOSE) << "Save onnx float32 model at: " << dump_path;
-    ASSERT_STATUS_OK(onnxruntime::Model::Save(f32_model, dump_path));
-  }
+  // Uncomment to save f32 model to disk for debugging.
+  // ASSERT_STATUS_OK(onnxruntime::Model::Save(f32_model, ToPathString("cmp_accuracy.f32.onnx")));

   // Run f32 model on CPU EP and collect outputs.
   std::vector cpu_f32_outputs;
@@ -674,27 +594,11 @@ inline void TestQDQModelAccuracy(const GetTestModelFn& f32_model_fn, const GetTe
   ASSERT_STATUS_OK(qdq_model.MainGraph().Resolve());
   qdq_model.ToProto().SerializeToString(&qdq_model_data);

-  if (QNNTestEnvironment::GetInstance().dump_onnx()) {
-    auto dump_path = output_dir / ToPathString("dumped_qdq_model.onnx");
-    LOGS(logging_manager.DefaultLogger(), VERBOSE) << "Save onnx QDQ model at: " << dump_path;
-    ASSERT_STATUS_OK(onnxruntime::Model::Save(qdq_model, dump_path));
-  }
+  // Uncomment to save QDQ model to disk for debugging.
+  // ASSERT_STATUS_OK(onnxruntime::Model::Save(qdq_model, ToPathString("cmp_accuracy.qdq.onnx")));

   bool is_qnn_ep = true;
   TryEnableQNNSaver(qnn_options);
-  if (QNNTestEnvironment::GetInstance().dump_dlc()) {
-    qnn_options["dump_qnn_ir_dlc"] = "1";
-    qnn_options["dump_qnn_ir_dlc_dir"] = output_dir.string();
-#if defined(_WIN32)
-    qnn_options["qnn_ir_backend_path"] = "QnnIr.dll";
-#else
-    qnn_options["qnn_ir_backend_path"] = "libQnnIr.so";
-#endif // defined(_WIN32)
-  }
-  if (QNNTestEnvironment::GetInstance().dump_json()) {
-    qnn_options["dump_json_qnn_graph"] = "1";
-    qnn_options["json_qnn_graph_dir"] = output_dir.string();
-  }
   std::vector qnn_qdq_outputs;
   if (!qnn_ctx_model_path.empty()) {
     onnx::ModelProto model_proto;
@@ -839,21 +743,11 @@ inline void TestFp16ModelAccuracy(const GetTestModelFn& f32_model_fn,
                                   logging::Severity log_severity = logging::Severity::kERROR,
                                   const std::string& qnn_ctx_model_path = "",
                                   const std::unordered_map& session_option_pairs = {}) {
-  std::filesystem::path output_dir;
-  if (QNNTestEnvironment::GetInstance().dump_onnx() ||
-      QNNTestEnvironment::GetInstance().dump_dlc() ||
-      QNNTestEnvironment::GetInstance().dump_json()) {
-    output_dir = QNNTestEnvironment::GetInstance().CreateTestcaseDirs();
-  }
   // Add kMSDomain to cover contrib op like Gelu
   const std::unordered_map domain_to_version = {{"", opset_version}, {kMSDomain, 1}};

   auto& logging_manager = DefaultLoggingManager();
   logging_manager.SetDefaultLoggerSeverity(log_severity);
-  if (QNNTestEnvironment::GetInstance().verbose()) {
-    logging_manager.RemoveSink(logging::SinkType::EtwSink);
-    logging_manager.SetDefaultLoggerSeverity(logging::Severity::kVERBOSE);
-  }

   // Create float model and serialize it to a string.
   onnxruntime::Model f32_model("f32_model", false, ModelMetaData(), PathString(),
@@ -866,12 +760,6 @@ inline void TestFp16ModelAccuracy(const GetTestModelFn& f32_model_fn,
   ASSERT_STATUS_OK(f32_model.MainGraph().Resolve());
   f32_model.ToProto().SerializeToString(&f32_model_data);

-  if (QNNTestEnvironment::GetInstance().dump_onnx()) {
-    auto dump_path = output_dir / ToPathString("dumped_f32_model.onnx");
-    LOGS(logging_manager.DefaultLogger(), VERBOSE) << "Save onnx float32 model at: " << dump_path;
-    ASSERT_STATUS_OK(onnxruntime::Model::Save(f32_model, dump_path));
-  }
-
   // Run f32 model on CPU EP and collect outputs.
   std::vector cpu_f32_outputs;
   InferenceModel(f32_model_data, "f32_model_logger", {}, ExpectedEPNodeAssignment::All,
@@ -908,27 +796,8 @@ inline void TestFp16ModelAccuracy(const GetTestModelFn& f32_model_fn,
   ASSERT_STATUS_OK(f16_model.MainGraph().Resolve());
   f16_model.ToProto().SerializeToString(&f16_model_data);

-  if (QNNTestEnvironment::GetInstance().dump_onnx()) {
-    auto dump_path = output_dir / ToPathString("dumped_f16_model.onnx");
-    LOGS(logging_manager.DefaultLogger(), VERBOSE) << "Save onnx float16 model at: " << dump_path;
-    ASSERT_STATUS_OK(onnxruntime::Model::Save(f16_model, dump_path));
-  }
-
   bool is_qnn_ep = true;
   TryEnableQNNSaver(qnn_options);
-  if (QNNTestEnvironment::GetInstance().dump_dlc()) {
-    qnn_options["dump_qnn_ir_dlc"] = "1";
-    qnn_options["dump_qnn_ir_dlc_dir"] = output_dir.string();
-#if defined(_WIN32)
-    qnn_options["qnn_ir_backend_path"] = "QnnIr.dll";
-#else
-    qnn_options["qnn_ir_backend_path"] = "libQnnIr.so";
-#endif // defined(_WIN32)
-  }
-  if (QNNTestEnvironment::GetInstance().dump_json()) {
-    qnn_options["dump_json_qnn_graph"] = "1";
-    qnn_options["json_qnn_graph_dir"] = output_dir.string();
-  }
   std::vector qnn_f16_outputs;
   if (!qnn_ctx_model_path.empty()) {
     onnx::ModelProto model_proto;
diff --git a/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc b/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc
index dce0d570ec238..6a6545c68cb4f 100644
--- a/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc
+++ b/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc
@@ -1,6 +1,5 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
+// Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
-#include "onnxruntime_cxx_api.h"
 #include "core/graph/onnx_protobuf.h"
 #include "core/session/inference_session.h"
 #include "test/providers/provider_test_utils.h"
@@ -19,8 +18,6 @@
 using namespace std;
 using namespace ONNX_NAMESPACE;
 using namespace ::onnxruntime::logging;

-extern std::unique_ptr ort_env;
-
 namespace onnxruntime {
 namespace test {
@@ -1363,49 +1360,5 @@ TEST(TensorrtExecutionProviderTest, RemoveCycleTest) {
   ASSERT_STATUS_OK(session_object.Run(run_options, feeds, output_names, &fetches));
   VerifyOutputs(fetches, expected_dims_mul_m, expected_values_mul_m);
 }
-
-TEST(TensorrtExecutionProviderTest, TestSessionOutputs) {
-  /*
-   * Model #1:
-   *
-   * "input" ---> TopK ---
-                         |---> "scores"
-                         |--- Less ---> "Less_output_0"
-                         |--- Div ---> "Div_output_0"
-                         |--- Mod ---> "labels"
-   */
-  {
-    OrtTensorRTProviderOptionsV2 provider_options;
-    Ort::SessionOptions session_options;
-    session_options.AppendExecutionProvider_TensorRT_V2(provider_options);
-
-    auto model_path = ORT_TSTR("testdata/topk_and_multiple_graph_outputs.onnx");
-    Ort::Session session(*ort_env, model_path, session_options);
-
-    size_t output_count = session.GetOutputCount();
-    ASSERT_TRUE(output_count == 4);
-  }
-
-  /*
-   * Model #2:
-   *
-   * "X" ---> Dropout ---> MatMul ---> "Y"
-   *            ^            |
-   *            |            |
-   *           "W" ------    ----> Can't be graph's output
-   *
-   */
-  {
-    OrtTensorRTProviderOptionsV2 provider_options;
-    Ort::SessionOptions session_options;
-    session_options.AppendExecutionProvider_TensorRT_V2(provider_options);
-
-    auto model_path = ORT_TSTR("testdata/node_output_not_used.onnx");
-    Ort::Session session(*ort_env, model_path, session_options);
-
-    size_t output_count = session.GetOutputCount();
-    ASSERT_TRUE(output_count == 1);
-  }
-}
 }  // namespace test
 }  // namespace onnxruntime
diff --git a/onnxruntime/test/testdata/node_output_not_used.onnx b/onnxruntime/test/testdata/node_output_not_used.onnx
deleted file mode 100644
index e2726182fddc2..0000000000000
Binary files a/onnxruntime/test/testdata/node_output_not_used.onnx and /dev/null differ
diff --git a/onnxruntime/test/testdata/node_output_not_used.py b/onnxruntime/test/testdata/node_output_not_used.py
deleted file mode 100644
index d36d5e9cfd2f8..0000000000000
--- a/onnxruntime/test/testdata/node_output_not_used.py
+++ /dev/null
@@ -1,43 +0,0 @@
-import onnx
-from onnx import TensorProto, helper
-
-
-def create_model_with_node_output_not_used(model_path):
-    # Create graph
-    x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [3, 2])
-    w = helper.make_tensor_value_info("W", TensorProto.FLOAT, [2, 3])
-    y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 3])
-
-    # Dropout node (two outputs)
-    dropout_node = helper.make_node(
-        "Dropout",
-        inputs=["X"],
-        outputs=["dropout_out", "dropout_mask"],
-        name="DropoutNode",
-    )
-
-    # MatMul node
-    matmul_node = helper.make_node(
-        "MatMul",
-        inputs=["dropout_out", "W"],
-        outputs=["Y"],
-        name="MatMulNode",
-    )
-
-    graph = helper.make_graph(
-        nodes=[dropout_node, matmul_node],
-        name="DropoutMatMulGraph",
-        inputs=[x, w],
-        outputs=[y],
-    )
-
-    model = helper.make_model(graph, opset_imports=[helper.make_operatorsetid("", 13)])
-
-    onnx.checker.check_model(model)
-    onnx.save(model, model_path)
-
-    print(f"Model saved to: {model_path}")
-
-
-if __name__ == "__main__":
-    create_model_with_node_output_not_used("node_output_not_used.onnx")
diff --git a/onnxruntime/test/testdata/topk_and_multiple_graph_outputs.onnx b/onnxruntime/test/testdata/topk_and_multiple_graph_outputs.onnx
deleted file mode 100644
index 340c3d420d574..0000000000000
Binary files a/onnxruntime/test/testdata/topk_and_multiple_graph_outputs.onnx and /dev/null differ
diff --git a/onnxruntime/test/testdata/topk_and_multiple_graph_outputs.py b/onnxruntime/test/testdata/topk_and_multiple_graph_outputs.py
deleted file mode 100644
index 232abb2ed9163..0000000000000
--- a/onnxruntime/test/testdata/topk_and_multiple_graph_outputs.py
+++ /dev/null
@@ -1,78 +0,0 @@
-import onnx
-from onnx import TensorProto, helper
-
-
-def create_model_with_topk_graph_output(model_path):
-    # ======================
-    # ---- Inputs ----
-    # ======================
-    input_tensor = helper.make_tensor_value_info("input", TensorProto.FLOAT, ["N"])
-
-    # ======================
-    # ---- Initializers ----
-    # ======================
-    k = helper.make_tensor("K", TensorProto.INT64, dims=[1], vals=[300])
-    zero = helper.make_tensor("zero", TensorProto.INT64, dims=[], vals=[0])
-    twenty_six = helper.make_tensor("twenty_six", TensorProto.INT64, dims=[], vals=[26])
-
-    # ======================
-    # ---- Nodes ----
-    # ======================
-    topk_node = helper.make_node(
-        "TopK",
-        inputs=["input", "K"],
-        outputs=["scores", "topk_indices"],
-        name="TopK",
-    )
-
-    less_node = helper.make_node(
-        "Less",
-        inputs=["topk_indices", "zero"],
-        outputs=["Less_output_0"],
-        name="Less",
-    )
-
-    div_node = helper.make_node(
-        "Div",
-        inputs=["topk_indices", "twenty_six"],
-        outputs=["Div_17_output_0"],
-        name="Div",
-    )
-
-    mod_node = helper.make_node(
-        "Mod",
-        inputs=["topk_indices", "twenty_six"],
-        outputs=["labels"],
-        name="Mod",
-    )
-
-    # =========================
-    # ---- Graph Outputs ----
-    # =========================
-    scores_out = helper.make_tensor_value_info("scores", TensorProto.FLOAT, ["K"])
-    less_out = helper.make_tensor_value_info("Less_output_0", TensorProto.BOOL, ["K"])
-    div_out = helper.make_tensor_value_info("Div_17_output_0", TensorProto.INT64, ["K"])
-    labels_out = helper.make_tensor_value_info("labels", TensorProto.INT64, ["K"])
-
-    # ======================
-    # ---- Graph ----
-    # ======================
-    graph = helper.make_graph(
-        nodes=[topk_node, less_node, div_node, mod_node],
-        name="TopKGraph",
-        inputs=[input_tensor],
-        outputs=[scores_out, less_out, div_out, labels_out],
-        initializer=[k, zero, twenty_six],
-    )
-
-    model = helper.make_model(graph, opset_imports=[helper.make_operatorsetid("", 13)])
-
-    # Validate + Save
-    onnx.checker.check_model(model)
-    onnx.save(model, model_path)
-
-    print(f"Model saved to: {model_path}")
-
-
-if __name__ == "__main__":
-    create_model_with_topk_graph_output("topk_and_multiple_graph_outputs.onnx")
diff --git a/tools/ci_build/github/linux/build_rocm_c_api_package.sh b/tools/ci_build/github/linux/build_rocm_c_api_package.sh
new file mode 100755
index 0000000000000..3ea90c73342a5
--- /dev/null
+++ b/tools/ci_build/github/linux/build_rocm_c_api_package.sh
@@ -0,0 +1,40 @@
+#!/bin/bash
+
+set -e -u -x
+
+usage() { echo "Usage: $0 -S -B -V [-H ] " 1>&2; exit 1; }
+
+ROCM_HOME=/opt/rocm
+
+while getopts S:B:V:H:I:P: parameter_Option; do
+  case "${parameter_Option}" in
+    S) SOURCE_DIR=${OPTARG};;
+    B) BINARY_DIR=${OPTARG};;
+    V) ROCM_VERSION=${OPTARG};;
+    H) ROCM_HOME=${OPTARG};;
+    I) IMAGE=${OPTARG};;
+    P) PYTHON_BIN=${OPTARG};;
+    *) usage ;;
+  esac
+done
+
+EXIT_CODE=1
+
+docker run -e SYSTEM_COLLECTIONURI --rm \
+  --security-opt seccomp=unconfined \
+  --shm-size=1024m \
+  --user $UID:$(id -g $USER) \
+  -e NIGHTLY_BUILD \
+  --volume $SOURCE_DIR:/onnxruntime_src \
+  --volume $BINARY_DIR:/build \
+  --volume /data/models:/build/models:ro \
+  --volume /data/onnx:/data/onnx:ro \
+  --workdir /onnxruntime_src \
+  $IMAGE \
+  /bin/bash -c "${PYTHON_BIN:-python} /onnxruntime_src/tools/ci_build/build.py --config Release --build_dir /build --parallel --use_rocm --use_binskim_compliant_compile_flags --rocm_version=$ROCM_VERSION --rocm_home $ROCM_HOME --nccl_home $ROCM_HOME --build_shared_lib --skip_submodule_sync --skip_tests --cmake_extra_defines FETCHCONTENT_TRY_FIND_PACKAGE_MODE=NEVER && cd /build/Release && make install DESTDIR=/build/installed"
+
+
+EXIT_CODE=$?
+
+set -e
+exit $EXIT_CODE
diff --git a/tools/ci_build/github/linux/docker/scripts/setup_rocm_yum_repo.sh b/tools/ci_build/github/linux/docker/scripts/setup_rocm_yum_repo.sh
new file mode 100755
index 0000000000000..0be64d96f3a34
--- /dev/null
+++ b/tools/ci_build/github/linux/docker/scripts/setup_rocm_yum_repo.sh
@@ -0,0 +1,43 @@
+#!/bin/bash
+set -e -x
+
+# version
+ROCM_VERSION=6.2.3
+
+while getopts "r:" parameter_Option
+do case "${parameter_Option}"
+in
+r) ROCM_VERSION=${OPTARG};;
+esac
+done
+
+tee /etc/yum.repos.d/amdgpu.repo <
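
For reference, the new build_rocm_c_api_package.sh is driven entirely by its getopts flags (-S source dir, -B build dir, -V ROCm version, plus optional -H ROCm home, -I docker image, -P python binary inside the image). A minimal invocation sketch, assuming a local checkout, a writable build directory, and an already-built ROCm docker image; every value below is a placeholder, not taken from the CI pipeline definitions:

    # Hypothetical invocation; -S/-B/-V/-I values are placeholders.
    ROCM_BUILD_IMAGE=my-rocm-build-image   # assumed to exist locally or in a registry
    ./tools/ci_build/github/linux/build_rocm_c_api_package.sh \
      -S "$PWD" \
      -B "$PWD/build/rocm" \
      -V 6.2.3 \
      -I "$ROCM_BUILD_IMAGE"

If -H is omitted, ROCM_HOME defaults to /opt/rocm, as set at the top of the script.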