diff --git a/.github/workflows/android.yml b/.github/workflows/android.yml index 7f7ff74959d52..f12eadc2ce794 100644 --- a/.github/workflows/android.yml +++ b/.github/workflows/android.yml @@ -27,7 +27,7 @@ jobs: runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: false @@ -112,7 +112,7 @@ jobs: android_nnapi_ep: runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - name: Use jdk 17 uses: actions/setup-java@v5 @@ -187,7 +187,7 @@ jobs: name: Android CI Pipeline runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - name: Use jdk 17 uses: actions/setup-java@v5 diff --git a/.github/workflows/cffconvert.yml b/.github/workflows/cffconvert.yml index 30f832f67c5ee..ddf4a52a0ccb0 100644 --- a/.github/workflows/cffconvert.yml +++ b/.github/workflows/cffconvert.yml @@ -12,7 +12,7 @@ jobs: runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] steps: - name: Check out a copy of the repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 - name: Check whether the citation metadata from CITATION.cff is valid uses: citation-file-format/cffconvert-github-action@2.0.0 diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml index d33e4d923a0bc..1db84400c272a 100644 --- a/.github/workflows/codeql.yml +++ b/.github/workflows/codeql.yml @@ -38,7 +38,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 # Initializes the CodeQL tools for scanning. - name: Initialize CodeQL diff --git a/.github/workflows/gradle-wrapper-validation.yml b/.github/workflows/gradle-wrapper-validation.yml index 04177b11e9c30..d8f13d13d3f88 100644 --- a/.github/workflows/gradle-wrapper-validation.yml +++ b/.github/workflows/gradle-wrapper-validation.yml @@ -15,7 +15,7 @@ jobs: name: "Validation" runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - uses: gradle/actions/wrapper-validation@v5 concurrency: group: ${{ github.workflow }}-${{ github.event_name == 'pull_request' && github.ref || github.sha }} diff --git a/.github/workflows/ios.yml b/.github/workflows/ios.yml index 0d2046b980783..ed572aa339ce9 100644 --- a/.github/workflows/ios.yml +++ b/.github/workflows/ios.yml @@ -20,7 +20,7 @@ jobs: runs-on: macos-14 steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: false - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.9 diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 5aaab5f8e1a10..5c618dc5787a5 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -17,7 +17,7 @@ jobs: name: Optional Lint runs-on: ubuntu-latest steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - name: misspell # Check spellings as well uses: reviewdog/action-misspell@v1 with: @@ -42,7 +42,7 @@ jobs: contents: read security-events: write steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - name: Setup Python uses: actions/setup-python@v6 with: @@ -87,7 +87,7 @@ jobs: name: Optional Lint C++ runs-on: ubuntu-latest steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - name: Update PATH run: | echo 
"$HOME/.local/bin" >> "$GITHUB_PATH" @@ -116,7 +116,7 @@ jobs: name: Lint JavaScript runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - uses: actions/setup-node@v6 with: node-version: 20 diff --git a/.github/workflows/linux-wasm-ci-build-and-test-workflow.yml b/.github/workflows/linux-wasm-ci-build-and-test-workflow.yml index 2370c631b7a7a..5763b9c39bcc6 100644 --- a/.github/workflows/linux-wasm-ci-build-and-test-workflow.yml +++ b/.github/workflows/linux-wasm-ci-build-and-test-workflow.yml @@ -49,7 +49,7 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: recursive diff --git a/.github/workflows/linux_cuda_ci.yml b/.github/workflows/linux_cuda_ci.yml index 886705471b7de..e7e3be8c5f9ed 100644 --- a/.github/workflows/linux_cuda_ci.yml +++ b/.github/workflows/linux_cuda_ci.yml @@ -48,7 +48,7 @@ jobs: packages: read steps: - name: Checkout code - uses: actions/checkout@v5 + uses: actions/checkout@v6 - uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.9 id: build_docker_image_step diff --git a/.github/workflows/linux_minimal_build.yml b/.github/workflows/linux_minimal_build.yml index af86975ee6cdc..4d9579a746892 100644 --- a/.github/workflows/linux_minimal_build.yml +++ b/.github/workflows/linux_minimal_build.yml @@ -28,7 +28,7 @@ jobs: packages: write steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: false @@ -65,7 +65,7 @@ jobs: id-token: write # If using OIDC for ACR login steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: false - uses: actions/setup-node@v6 @@ -122,7 +122,7 @@ jobs: id-token: write # If using OIDC for ACR login steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: false - uses: actions/setup-node@v6 @@ -156,7 +156,7 @@ jobs: id-token: write # If using OIDC for ACR login steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: false - uses: actions/setup-node@v6 @@ -188,7 +188,7 @@ jobs: id-token: write # If using OIDC for ACR login steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: false - uses: actions/setup-node@v6 @@ -222,7 +222,7 @@ jobs: id-token: write # If using OIDC for ACR login steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: false - uses: actions/setup-node@v6 @@ -286,7 +286,7 @@ jobs: id-token: write # If using OIDC for ACR login steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: false @@ -363,7 +363,7 @@ jobs: id-token: write # If using OIDC for ACR login steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: false @@ -430,7 +430,7 @@ jobs: id-token: write # If using OIDC for ACR login steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: false @@ -505,7 +505,7 @@ jobs: id-token: write # If using OIDC for ACR login steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: false - uses: actions/setup-node@v6 diff --git a/.github/workflows/linux_tensorrt_ci.yml b/.github/workflows/linux_tensorrt_ci.yml index 0e26576829e94..47b7c1ba7e889 100644 
--- a/.github/workflows/linux_tensorrt_ci.yml +++ b/.github/workflows/linux_tensorrt_ci.yml @@ -48,7 +48,7 @@ jobs: packages: read steps: - name: Checkout code - uses: actions/checkout@v5 + uses: actions/checkout@v6 # --- Build the Docker image needed for testing --- - name: Build Docker Image for Testing diff --git a/.github/workflows/mac.yml b/.github/workflows/mac.yml index e545406d8d20f..8ba87bc1f731c 100644 --- a/.github/workflows/mac.yml +++ b/.github/workflows/mac.yml @@ -76,7 +76,7 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v5 + uses: actions/checkout@v6 - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.9 with: vcpkg-version: '2025.06.13' @@ -124,7 +124,7 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v5 + uses: actions/checkout@v6 - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.9 with: vcpkg-version: '2025.06.13' diff --git a/.github/workflows/macos-ci-build-and-test-workflow.yml b/.github/workflows/macos-ci-build-and-test-workflow.yml index 329584c68d7d1..8e1d0264496f6 100644 --- a/.github/workflows/macos-ci-build-and-test-workflow.yml +++ b/.github/workflows/macos-ci-build-and-test-workflow.yml @@ -75,7 +75,7 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v5 + uses: actions/checkout@v6 - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.9 with: vcpkg-version: '2025.06.13' diff --git a/.github/workflows/pr_checks.yml b/.github/workflows/pr_checks.yml index abe627f4ff7bc..7ca330742f69b 100644 --- a/.github/workflows/pr_checks.yml +++ b/.github/workflows/pr_checks.yml @@ -24,7 +24,7 @@ jobs: contents: read pull-requests: write steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - name: Setup Python uses: actions/setup-python@v6 with: diff --git a/.github/workflows/publish-c-apidocs.yml b/.github/workflows/publish-c-apidocs.yml index 25b7899584bbf..d9fb72271967f 100644 --- a/.github/workflows/publish-c-apidocs.yml +++ b/.github/workflows/publish-c-apidocs.yml @@ -24,7 +24,7 @@ jobs: name: Generate C/C++ API docs runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - name: Install doxygen and dependencies run: | sudo apt update diff --git a/.github/workflows/publish-csharp-apidocs.yml b/.github/workflows/publish-csharp-apidocs.yml index 34b9c1af9552f..dd55bbd917337 100644 --- a/.github/workflows/publish-csharp-apidocs.yml +++ b/.github/workflows/publish-csharp-apidocs.yml @@ -24,7 +24,7 @@ jobs: env: DOCFXVERSION: 2.62.2 steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - name: Install DocFX run: | dotnet tool update -g docfx diff --git a/.github/workflows/publish-java-apidocs.yml b/.github/workflows/publish-java-apidocs.yml index 656d0627ed17d..81defeae518a3 100644 --- a/.github/workflows/publish-java-apidocs.yml +++ b/.github/workflows/publish-java-apidocs.yml @@ -23,7 +23,7 @@ jobs: name: Generate Java docs runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - name: Set up JDK 11 uses: actions/setup-java@v5 with: diff --git a/.github/workflows/publish-js-apidocs.yml b/.github/workflows/publish-js-apidocs.yml index e71d3b3c57a4b..9da78d7d9ed9c 100644 --- a/.github/workflows/publish-js-apidocs.yml +++ b/.github/workflows/publish-js-apidocs.yml @@ -23,7 +23,7 @@ jobs: name: Generate JS API docs runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] 
steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - name: Setup Node.js uses: actions/setup-node@v6 with: diff --git a/.github/workflows/publish-objectivec-apidocs.yml b/.github/workflows/publish-objectivec-apidocs.yml index 983d3d478a49d..a73b62eba6050 100644 --- a/.github/workflows/publish-objectivec-apidocs.yml +++ b/.github/workflows/publish-objectivec-apidocs.yml @@ -23,7 +23,7 @@ jobs: name: Generate Objective-C API docs runs-on: macos-latest steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.9 with: vcpkg-version: '2025.06.13' diff --git a/.github/workflows/publish-python-apidocs.yml b/.github/workflows/publish-python-apidocs.yml index 389d1683fb1ff..e35e6a04adbef 100644 --- a/.github/workflows/publish-python-apidocs.yml +++ b/.github/workflows/publish-python-apidocs.yml @@ -24,7 +24,7 @@ jobs: name: Generate Python API docs runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - name: Install tools run: | sudo apt-get update diff --git a/.github/workflows/react_native.yml b/.github/workflows/react_native.yml index 343186b1aec8c..4a56dfbd35406 100644 --- a/.github/workflows/react_native.yml +++ b/.github/workflows/react_native.yml @@ -20,7 +20,7 @@ jobs: aar_path: ${{ runner.temp }}/.artifacts steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: false @@ -75,7 +75,7 @@ jobs: run: echo "ANDROID_AVD_HOME=${{ runner.temp }}/android-avd" >> $GITHUB_ENV - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 - name: Use Python 3.12 uses: actions/setup-python@v6 @@ -175,7 +175,7 @@ jobs: timeout-minutes: 120 steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 - name: Use Xcode 15.3.0 run: sudo xcode-select --switch /Applications/Xcode_15.3.0.app/Contents/Developer @@ -218,7 +218,7 @@ jobs: timeout-minutes: 90 steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 - name: Download iOS pod artifact uses: actions/download-artifact@v6 diff --git a/.github/workflows/reusable_linux_build.yml b/.github/workflows/reusable_linux_build.yml index 795e35b06bfb0..f0da87647b8b0 100644 --- a/.github/workflows/reusable_linux_build.yml +++ b/.github/workflows/reusable_linux_build.yml @@ -75,7 +75,7 @@ jobs: id-token: write steps: - name: Checkout code - uses: actions/checkout@v5 + uses: actions/checkout@v6 - name: Set up Python ${{ inputs.python_version }} if: inputs.architecture != 'arm64' diff --git a/.github/workflows/web.yml b/.github/workflows/web.yml index 016feab5e0d94..6ae25ccc0bf3e 100644 --- a/.github/workflows/web.yml +++ b/.github/workflows/web.yml @@ -22,7 +22,7 @@ jobs: commit_sha: ${{ steps.extract_commit.outputs.commit_sha }} steps: - name: Checkout code - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: true diff --git a/.github/workflows/windows-web-ci-workflow.yml b/.github/workflows/windows-web-ci-workflow.yml index eee98332056f6..c16ce6eb222eb 100644 --- a/.github/workflows/windows-web-ci-workflow.yml +++ b/.github/workflows/windows-web-ci-workflow.yml @@ -29,7 +29,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: false diff --git a/.github/workflows/windows_build_x64_asan.yml b/.github/workflows/windows_build_x64_asan.yml index 
05fd4acd4de9a..ac5f08717155f 100644 --- a/.github/workflows/windows_build_x64_asan.yml +++ b/.github/workflows/windows_build_x64_asan.yml @@ -19,7 +19,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: false diff --git a/.github/workflows/windows_cuda.yml b/.github/workflows/windows_cuda.yml index fd5b65eb039a3..5d6e9b1da31a2 100644 --- a/.github/workflows/windows_cuda.yml +++ b/.github/workflows/windows_cuda.yml @@ -21,7 +21,7 @@ jobs: name: Windows GPU CUDA CI Pipeline runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-vs2022-latest"] steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 with: fetch-depth: 0 submodules: 'none' @@ -152,7 +152,7 @@ jobs: timeout-minutes: 300 runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Win2022-GPU-A10"] steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 with: fetch-depth: 0 submodules: 'none' diff --git a/.github/workflows/windows_dml.yml b/.github/workflows/windows_dml.yml index e8ee7751348b4..0abf6b650f986 100644 --- a/.github/workflows/windows_dml.yml +++ b/.github/workflows/windows_dml.yml @@ -27,7 +27,7 @@ jobs: ONNXRUNTIME_TEST_GPU_DEVICE_ID: '0' runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Win2022-GPU-A10"] steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 with: fetch-depth: 0 # Fetch all history for all tags and branches submodules: 'none' diff --git a/.github/workflows/windows_openvino.yml b/.github/workflows/windows_openvino.yml index b608c0879aa45..537ff1fb00071 100644 --- a/.github/workflows/windows_openvino.yml +++ b/.github/workflows/windows_openvino.yml @@ -31,7 +31,7 @@ jobs: steps: - name: Checkout - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: fetch-depth: 0 submodules: none diff --git a/.github/workflows/windows_qnn_x64.yml b/.github/workflows/windows_qnn_x64.yml index 4f0b50e65df6e..f6176164354bb 100644 --- a/.github/workflows/windows_qnn_x64.yml +++ b/.github/workflows/windows_qnn_x64.yml @@ -31,7 +31,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 - name: Setup Python uses: actions/setup-python@v6 diff --git a/.github/workflows/windows_tensorrt.yml b/.github/workflows/windows_tensorrt.yml index 229efb01f0018..4a564a3b1cb36 100644 --- a/.github/workflows/windows_tensorrt.yml +++ b/.github/workflows/windows_tensorrt.yml @@ -21,7 +21,7 @@ jobs: name: Windows GPU TensorRT CI Pipeline runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-vs2022-latest"] steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 with: fetch-depth: 0 submodules: 'none' @@ -157,7 +157,7 @@ jobs: timeout-minutes: 300 runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Win2022-GPU-A10"] steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 with: fetch-depth: 0 submodules: 'none' diff --git a/.github/workflows/windows_webgpu.yml b/.github/workflows/windows_webgpu.yml index 899a8b66eac7a..f729cda5ea576 100644 --- a/.github/workflows/windows_webgpu.yml +++ b/.github/workflows/windows_webgpu.yml @@ -34,7 +34,7 @@ jobs: ONNXRUNTIME_TEST_GPU_DEVICE_ID: "0" steps: - name: Checkout - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: fetch-depth: 0 submodules: none @@ -156,7 +156,7 @@ jobs: timeout-minutes: 300 steps: - name: Checkout - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: fetch-depth: 0 submodules: none @@ -209,7 +209,7 @@ jobs: ONNXRUNTIME_TEST_GPU_DEVICE_ID: "0" steps: - name: Checkout - uses: 
actions/checkout@v5
+        uses: actions/checkout@v6
         with:
           fetch-depth: 0
           submodules: none
diff --git a/.github/workflows/windows_x64_debug_build_x64_debug.yml b/.github/workflows/windows_x64_debug_build_x64_debug.yml
index d62c7130e0ebb..385d03c1a6705 100644
--- a/.github/workflows/windows_x64_debug_build_x64_debug.yml
+++ b/.github/workflows/windows_x64_debug_build_x64_debug.yml
@@ -18,7 +18,7 @@ jobs:
 
     steps:
       - name: Checkout repository
-        uses: actions/checkout@v5
+        uses: actions/checkout@v6
         with:
           submodules: false
diff --git a/.github/workflows/windows_x64_release_build_x64_release.yml b/.github/workflows/windows_x64_release_build_x64_release.yml
index a2991bb0f1131..ee045b70b6efa 100644
--- a/.github/workflows/windows_x64_release_build_x64_release.yml
+++ b/.github/workflows/windows_x64_release_build_x64_release.yml
@@ -18,7 +18,7 @@ jobs:
 
     steps:
       - name: Checkout repository
-        uses: actions/checkout@v5
+        uses: actions/checkout@v6
         with:
           submodules: false
diff --git a/.github/workflows/windows_x64_release_ep_generic_interface_build_x64_release_ep_generic_interface.yml b/.github/workflows/windows_x64_release_ep_generic_interface_build_x64_release_ep_generic_interface.yml
index bb6c5035b0dce..25dfc41e6922c 100644
--- a/.github/workflows/windows_x64_release_ep_generic_interface_build_x64_release_ep_generic_interface.yml
+++ b/.github/workflows/windows_x64_release_ep_generic_interface_build_x64_release_ep_generic_interface.yml
@@ -18,7 +18,7 @@ jobs:
 
     steps:
       - name: Checkout repository
-        uses: actions/checkout@v5
+        uses: actions/checkout@v6
         with:
           submodules: false
diff --git a/.github/workflows/windows_x64_release_vitisai_build_x64_release.yml b/.github/workflows/windows_x64_release_vitisai_build_x64_release.yml
index 4378231338673..e738db262f3a2 100644
--- a/.github/workflows/windows_x64_release_vitisai_build_x64_release.yml
+++ b/.github/workflows/windows_x64_release_vitisai_build_x64_release.yml
@@ -18,7 +18,7 @@ jobs:
 
     steps:
       - name: Checkout repository
-        uses: actions/checkout@v5
+        uses: actions/checkout@v6
         with:
           submodules: false
diff --git a/.github/workflows/windows_x64_release_xnnpack.yml b/.github/workflows/windows_x64_release_xnnpack.yml
index b453cd570ac05..5672e4043c624 100644
--- a/.github/workflows/windows_x64_release_xnnpack.yml
+++ b/.github/workflows/windows_x64_release_xnnpack.yml
@@ -18,7 +18,7 @@ jobs:
 
     steps:
       - name: Checkout repository
-        uses: actions/checkout@v5
+        uses: actions/checkout@v6
        with:
           submodules: false
diff --git a/.github/workflows/windows_x86.yml b/.github/workflows/windows_x86.yml
index d20778d56f60b..381d9dda5cd42 100644
--- a/.github/workflows/windows_x86.yml
+++ b/.github/workflows/windows_x86.yml
@@ -18,7 +18,7 @@ jobs:
 
     steps:
       - name: Checkout repository
-        uses: actions/checkout@v5
+        uses: actions/checkout@v6
         with:
           submodules: false
diff --git a/dockerfiles/Dockerfile.rocm b/dockerfiles/Dockerfile.rocm
deleted file mode 100644
index aca8c3feaff71..0000000000000
--- a/dockerfiles/Dockerfile.rocm
+++ /dev/null
@@ -1,24 +0,0 @@
-# --------------------------------------------------------------
-# Copyright (c) Microsoft Corporation. All rights reserved.
-# Licensed under the MIT License.
-# --------------------------------------------------------------
-# Dockerfile to run ONNXRuntime with ROCm integration
-#--------------------------------------------------------------------------
-
-FROM rocm/pytorch:rocm6.2.3_ubuntu22.04_py3.10_pytorch_release_2.3.0
-
-ARG ONNXRUNTIME_REPO=https://github.com/Microsoft/onnxruntime
-ARG ONNXRUNTIME_BRANCH=main
-
-WORKDIR /code
-
-ENV PATH=/code/cmake-3.27.3-linux-x86_64/bin:${PATH}
-
-# Prepare onnxruntime repository & build onnxruntime
-RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXRUNTIME_REPO} onnxruntime &&\
-    /bin/sh onnxruntime/dockerfiles/scripts/install_common_deps.sh &&\
-    cd onnxruntime &&\
-    /bin/sh ./build.sh --allow_running_as_root --config Release --build_wheel --update --build --parallel --cmake_extra_defines\
-    ONNXRUNTIME_VERSION=$(cat ./VERSION_NUMBER) --use_rocm --rocm_home=/opt/rocm &&\
-    pip install /code/onnxruntime/build/Linux/Release/dist/*.whl &&\
-    cd ..
diff --git a/dockerfiles/README.md b/dockerfiles/README.md
index 4c69098103edd..88c542b63ccd2 100644
--- a/dockerfiles/README.md
+++ b/dockerfiles/README.md
@@ -1,9 +1,8 @@
 # Dockerfiles
 **Execution Providers**
 - CPU: [Dockerfile](Dockerfile.source), [Instructions](#cpu)
-- CUDA/cuDNN: [Dockerfile](Dockerfile.cuda), [Instructions](#cuda)
+- CUDA: [Dockerfile](Dockerfile.cuda), [Instructions](#cuda)
 - MIGraphX: [Dockerfile](Dockerfile.migraphx), [Instructions](#migraphx)
-- ROCm: [Dockerfile](Dockerfile.rocm), [Instructions](#rocm)
 - OpenVINO: [Dockerfile](Dockerfile.openvino), [Instructions](#openvino)
 - TensorRT: [Dockerfile](Dockerfile.tensorrt), [Instructions](#tensorrt)
 - VitisAI: [Dockerfile](Dockerfile.vitisai)
@@ -304,17 +303,3 @@ Note: When running the container you built in Docker, please either use 'nvidia-
 ```
 docker run -it --device=/dev/kfd --device=/dev/dri --group-add video onnxruntime-migraphx
 ```
-
- ## ROCm
-**Ubuntu 22.04, ROCm6.2.3**
-
-1. Build the docker image from the Dockerfile in this repository.
-   ```
-   docker build -t onnxruntime-rocm -f Dockerfile.rocm .
-   ```
-
-2. Run the Docker image
-
-   ```
-   docker run -it --device=/dev/kfd --device=/dev/dri --group-add video onnxruntime-rocm
-   ```
diff --git a/dockerfiles/scripts/install_rocm_deps.sh b/dockerfiles/scripts/install_rocm_deps.sh
deleted file mode 100644
index fd445be87479b..0000000000000
--- a/dockerfiles/scripts/install_rocm_deps.sh
+++ /dev/null
@@ -1,84 +0,0 @@
-#!/bin/bash
-prefix=/opt/rocm
-DEBIAN_FRONTEND=noninteractive
-apt-get update && apt-get install -y --no-install-recommends \
-  wget \
-  zip \
-  ca-certificates \
-  build-essential \
-  curl \
-  libcurl4-openssl-dev \
-  libssl-dev \
-  python3-dev
-
-# rocm-cmake
-rocm_cmake_version=4.5.2
-wget --quiet https://github.com/RadeonOpenCompute/rocm-cmake/archive/refs/tags/rocm-${rocm_cmake_version}.tar.gz
-tar -xzvf rocm-${rocm_cmake_version}.tar.gz
-rm rocm-${rocm_cmake_version}.tar.gz
-cd rocm-cmake-rocm-${rocm_cmake_version}
-mkdir build
-cd build
-cmake -DCMAKE_INSTALL_PREFIX=$prefix ..
-make -j8
-make install
-cd ../..
-rm -rf rocm-cmake-rocm-${rocm_cmake_version}
-
-# rccl
-rccl_version=4.5.2
-wget --quiet https://github.com/ROCmSoftwarePlatform/rccl/archive/refs/tags/rocm-${rccl_version}.tar.gz
-tar -xzvf rocm-${rccl_version}.tar.gz
-rm rocm-${rccl_version}.tar.gz
-cd rccl-rocm-${rccl_version}
-mkdir build
-cd build
-CXX=/opt/rocm/bin/hipcc cmake -DCMAKE_INSTALL_PREFIX=$prefix ..
-make -j8
-make install
-cd ../..
-rm -rf rccl-rocm-${rccl_version}
-
-#rocrand
-rocrand_version=4.5.2
-wget --quiet https://github.com/ROCmSoftwarePlatform/rocRAND/archive/refs/tags/rocm-${rocrand_version}.tar.gz
-tar -xzvf rocm-${rocrand_version}.tar.gz
-rm rocm-${rocrand_version}.tar.gz
-cd rocRAND-rocm-${rocrand_version}
-mkdir build
-cd build
-CXX=/opt/rocm/bin/hipcc cmake -DCMAKE_INSTALL_PREFIX=$prefix ..
-make -j8
-make install
-cd ../..
-rm -rf rocRAND-rocm-${rocrand_version}
-
-#hipcub
-hipcub_version=4.5.2
-wget --quiet https://github.com/ROCmSoftwarePlatform/hipCUB/archive/refs/tags/rocm-${hipcub_version}.tar.gz
-tar -xzvf rocm-${hipcub_version}.tar.gz
-rm rocm-${hipcub_version}.tar.gz
-cd hipCUB-rocm-${hipcub_version}
-mkdir build
-cd build
-CXX=/opt/rocm/bin/hipcc cmake -DCMAKE_INSTALL_PREFIX=$prefix ..
-make -j8
-make package
-make install
-cd ../..
-rm -rf hipCUB-rocm-${hipcub_version}
-
-#rocprim
-rocprim_version=4.5.2
-wget --quiet https://github.com/ROCmSoftwarePlatform/rocPRIM/archive/refs/tags/rocm-${rocprim_version}.tar.gz
-tar -xzvf rocm-${rocprim_version}.tar.gz
-rm rocm-${rocprim_version}.tar.gz
-cd rocPRIM-rocm-${rocprim_version}
-mkdir build
-cd build
-CXX=/opt/rocm/bin/hipcc cmake -DCMAKE_INSTALL_PREFIX=$prefix ..
-make -j8
-make install
-cd ../..
-rm -rf rocPRIM-rocm-${rocprim_version}
-
diff --git a/js/package-lock.json b/js/package-lock.json
index 1e9f5cb29fe6c..0fca515b61238 100644
--- a/js/package-lock.json
+++ b/js/package-lock.json
@@ -4,6 +4,7 @@
   "requires": true,
   "packages": {
     "": {
+      "name": "js",
       "license": "MIT",
       "devDependencies": {
         "@eslint/compat": "^1.4.0",
@@ -3230,6 +3231,27 @@
         "url": "https://github.com/sponsors/ljharb"
       }
     },
+    "node_modules/glob": {
+      "version": "10.5.0",
+      "resolved": "https://registry.npmjs.org/glob/-/glob-10.5.0.tgz",
+      "integrity": "sha512-DfXN8DfhJ7NH3Oe7cFmu3NCu1wKbkReJ8TorzSAFbSKrlNaQSKfIzqYqVY8zlbs2NLBbWpRiU52GX2PbaBVNkg==",
+      "dev": true,
+      "license": "ISC",
+      "dependencies": {
+        "foreground-child": "^3.1.0",
+        "jackspeak": "^3.1.2",
+        "minimatch": "^9.0.4",
+        "minipass": "^7.1.2",
+        "package-json-from-dist": "^1.0.0",
+        "path-scurry": "^1.11.1"
+      },
+      "bin": {
+        "glob": "dist/esm/bin.mjs"
+      },
+      "funding": {
+        "url": "https://github.com/sponsors/isaacs"
+      }
+    },
     "node_modules/glob-parent": {
       "version": "6.0.2",
       "resolved": "https://registry.npmjs.org/glob-parent/-/glob-parent-6.0.2.tgz",
@@ -3242,6 +3264,32 @@
         "node": ">=10.13.0"
       }
     },
+    "node_modules/glob/node_modules/brace-expansion": {
+      "version": "2.0.2",
+      "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.2.tgz",
+      "integrity": "sha512-Jt0vHyM+jmUBqojB7E1NIYadt0vI0Qxjxd2TErW94wDz+E2LAm5vKMXXwg6ZZBTHPuUlDgQHKXvjGBdfcF1ZDQ==",
+      "dev": true,
+      "license": "MIT",
+      "dependencies": {
+        "balanced-match": "^1.0.0"
+      }
+    },
+    "node_modules/glob/node_modules/minimatch": {
+      "version": "9.0.5",
+      "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.5.tgz",
+      "integrity": "sha512-G6T0ZX48xgozx7587koeX9Ys2NYy6Gmv//P89sEte9V9whIapMNF4idKxnW2QtCcLiTWlb/wfCabAtAFWhhBow==",
+      "dev": true,
+      "license": "ISC",
+      "dependencies": {
+        "brace-expansion": "^2.0.1"
+      },
+      "engines": {
+        "node": ">=16 || 14 >=14.17"
+      },
+      "funding": {
+        "url": "https://github.com/sponsors/isaacs"
+      }
+    },
     "node_modules/global-agent": {
       "version": "3.0.0",
       "resolved": "https://registry.npmjs.org/global-agent/-/global-agent-3.0.0.tgz",
@@ -4311,43 +4359,6 @@
         "balanced-match": "^1.0.0"
       }
     },
-    "node_modules/mocha/node_modules/glob": {
-      "version": "10.4.5",
-      "resolved":
"https://registry.npmjs.org/glob/-/glob-10.4.5.tgz", - "integrity": "sha512-7Bv8RF0k6xjo7d4A/PxYLbUCfb6c+Vpd2/mB2yRDlew7Jb5hEXiCD9ibfO7wpk8i4sevK6DFny9h7EYbM3/sHg==", - "dev": true, - "license": "ISC", - "dependencies": { - "foreground-child": "^3.1.0", - "jackspeak": "^3.1.2", - "minimatch": "^9.0.4", - "minipass": "^7.1.2", - "package-json-from-dist": "^1.0.0", - "path-scurry": "^1.11.1" - }, - "bin": { - "glob": "dist/esm/bin.mjs" - }, - "funding": { - "url": "https://github.com/sponsors/isaacs" - } - }, - "node_modules/mocha/node_modules/glob/node_modules/minimatch": { - "version": "9.0.5", - "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.5.tgz", - "integrity": "sha512-G6T0ZX48xgozx7587koeX9Ys2NYy6Gmv//P89sEte9V9whIapMNF4idKxnW2QtCcLiTWlb/wfCabAtAFWhhBow==", - "dev": true, - "license": "ISC", - "dependencies": { - "brace-expansion": "^2.0.1" - }, - "engines": { - "node": ">=16 || 14 >=14.17" - }, - "funding": { - "url": "https://github.com/sponsors/isaacs" - } - }, "node_modules/mocha/node_modules/minimatch": { "version": "5.1.6", "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-5.1.6.tgz", @@ -8078,6 +8089,40 @@ "get-intrinsic": "^1.2.6" } }, + "glob": { + "version": "10.5.0", + "resolved": "https://registry.npmjs.org/glob/-/glob-10.5.0.tgz", + "integrity": "sha512-DfXN8DfhJ7NH3Oe7cFmu3NCu1wKbkReJ8TorzSAFbSKrlNaQSKfIzqYqVY8zlbs2NLBbWpRiU52GX2PbaBVNkg==", + "dev": true, + "requires": { + "foreground-child": "^3.1.0", + "jackspeak": "^3.1.2", + "minimatch": "^9.0.4", + "minipass": "^7.1.2", + "package-json-from-dist": "^1.0.0", + "path-scurry": "^1.11.1" + }, + "dependencies": { + "brace-expansion": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.2.tgz", + "integrity": "sha512-Jt0vHyM+jmUBqojB7E1NIYadt0vI0Qxjxd2TErW94wDz+E2LAm5vKMXXwg6ZZBTHPuUlDgQHKXvjGBdfcF1ZDQ==", + "dev": true, + "requires": { + "balanced-match": "^1.0.0" + } + }, + "minimatch": { + "version": "9.0.5", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.5.tgz", + "integrity": "sha512-G6T0ZX48xgozx7587koeX9Ys2NYy6Gmv//P89sEte9V9whIapMNF4idKxnW2QtCcLiTWlb/wfCabAtAFWhhBow==", + "dev": true, + "requires": { + "brace-expansion": "^2.0.1" + } + } + } + }, "glob-parent": { "version": "6.0.2", "resolved": "https://registry.npmjs.org/glob-parent/-/glob-parent-6.0.2.tgz", @@ -8772,31 +8817,6 @@ "balanced-match": "^1.0.0" } }, - "glob": { - "version": "10.4.5", - "resolved": "https://registry.npmjs.org/glob/-/glob-10.4.5.tgz", - "integrity": "sha512-7Bv8RF0k6xjo7d4A/PxYLbUCfb6c+Vpd2/mB2yRDlew7Jb5hEXiCD9ibfO7wpk8i4sevK6DFny9h7EYbM3/sHg==", - "dev": true, - "requires": { - "foreground-child": "^3.1.0", - "jackspeak": "^3.1.2", - "minimatch": "^9.0.4", - "minipass": "^7.1.2", - "package-json-from-dist": "^1.0.0", - "path-scurry": "^1.11.1" - }, - "dependencies": { - "minimatch": { - "version": "9.0.5", - "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.5.tgz", - "integrity": "sha512-G6T0ZX48xgozx7587koeX9Ys2NYy6Gmv//P89sEte9V9whIapMNF4idKxnW2QtCcLiTWlb/wfCabAtAFWhhBow==", - "dev": true, - "requires": { - "brace-expansion": "^2.0.1" - } - } - } - }, "minimatch": { "version": "5.1.6", "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-5.1.6.tgz", diff --git a/js/react_native/package-lock.json b/js/react_native/package-lock.json index e6ed2bdb9e17b..de8d631362db7 100644 --- a/js/react_native/package-lock.json +++ b/js/react_native/package-lock.json @@ -33,6 +33,7 @@ "version": "1.24.0", "license": "MIT", 
"devDependencies": { + "globby": "^15.0.0", "typedoc": "^0.25.7" } }, @@ -61,15 +62,15 @@ } }, "node_modules/@babel/code-frame": { - "version": "7.26.2", - "resolved": "https://registry.npmjs.org/@babel/code-frame/-/code-frame-7.26.2.tgz", - "integrity": "sha512-RJlIHRueQgwWitWgF8OdFYGZX328Ax5BCemNGlqHfplnRT9ESi8JkFlvaVYbS+UubVY6dpv87Fs2u5M29iNFVQ==", + "version": "7.27.1", + "resolved": "https://registry.npmjs.org/@babel/code-frame/-/code-frame-7.27.1.tgz", + "integrity": "sha512-cjQ7ZlQ0Mv3b47hABuTevyTuYN4i+loJKGeV9flcCgIK37cCXRh+L1bd3iBHlynerhQ7BhCkn2BPbQUL+rGqFg==", "dev": true, "license": "MIT", "dependencies": { - "@babel/helper-validator-identifier": "^7.25.9", + "@babel/helper-validator-identifier": "^7.27.1", "js-tokens": "^4.0.0", - "picocolors": "^1.0.0" + "picocolors": "^1.1.1" }, "engines": { "node": ">=6.9.0" @@ -410,9 +411,9 @@ } }, "node_modules/@babel/helper-string-parser": { - "version": "7.25.9", - "resolved": "https://registry.npmjs.org/@babel/helper-string-parser/-/helper-string-parser-7.25.9.tgz", - "integrity": "sha512-4A/SCr/2KLd5jrtOMFzaKjVtAei3+2r/NChoBNoZ3EyP/+GlhoaEGoWOZUmFmoITP7zOJyHIMm+DYRd8o3PvHA==", + "version": "7.27.1", + "resolved": "https://registry.npmjs.org/@babel/helper-string-parser/-/helper-string-parser-7.27.1.tgz", + "integrity": "sha512-qMlSxKbpRlAridDExk92nSobyDdpPijUq2DW6oDnUqd0iOGxmQjyqhMIihI9+zv4LPyZdRje2cavWPbCbWm3eA==", "dev": true, "license": "MIT", "engines": { @@ -420,9 +421,9 @@ } }, "node_modules/@babel/helper-validator-identifier": { - "version": "7.25.9", - "resolved": "https://registry.npmjs.org/@babel/helper-validator-identifier/-/helper-validator-identifier-7.25.9.tgz", - "integrity": "sha512-Ed61U6XJc3CVRfkERJWDz4dJwKe7iLmmJsbOGu9wSloNSFttHV0I8g6UAgb7qnK5ly5bGLPd4oXZlxCdANBOWQ==", + "version": "7.28.5", + "resolved": "https://registry.npmjs.org/@babel/helper-validator-identifier/-/helper-validator-identifier-7.28.5.tgz", + "integrity": "sha512-qSs4ifwzKJSV39ucNjsvc6WVHs6b7S03sOh2OcHF9UHfVPqWWALUsNUVzhSBiItjRZoLHx7nIarVjqKVusUZ1Q==", "dev": true, "license": "MIT", "engines": { @@ -455,27 +456,27 @@ } }, "node_modules/@babel/helpers": { - "version": "7.25.6", - "resolved": "https://registry.npmjs.org/@babel/helpers/-/helpers-7.25.6.tgz", - "integrity": "sha512-Xg0tn4HcfTijTwfDwYlvVCl43V6h4KyVVX2aEm4qdO/PC6L2YvzLHFdmxhoeSA3eslcE6+ZVXHgWwopXYLNq4Q==", + "version": "7.28.4", + "resolved": "https://registry.npmjs.org/@babel/helpers/-/helpers-7.28.4.tgz", + "integrity": "sha512-HFN59MmQXGHVyYadKLVumYsA9dBFun/ldYxipEjzA4196jpLZd8UjEEBLkbEkvfYreDqJhZxYAWFPtrfhNpj4w==", "dev": true, "license": "MIT", "dependencies": { - "@babel/template": "^7.25.0", - "@babel/types": "^7.25.6" + "@babel/template": "^7.27.2", + "@babel/types": "^7.28.4" }, "engines": { "node": ">=6.9.0" } }, "node_modules/@babel/parser": { - "version": "7.26.9", - "resolved": "https://registry.npmjs.org/@babel/parser/-/parser-7.26.9.tgz", - "integrity": "sha512-81NWa1njQblgZbQHxWHpxxCzNsa3ZwvFqpUg7P+NNUU6f3UU2jBEg4OlF/J6rl8+PQGh1q6/zWScd001YwcA5A==", + "version": "7.28.5", + "resolved": "https://registry.npmjs.org/@babel/parser/-/parser-7.28.5.tgz", + "integrity": "sha512-KKBU1VGYR7ORr3At5HAtUQ+TV3SzRCXmA/8OdDZiLDBIZxVyzXuztPjfLd3BV1PRAQGCMWWSHYhL0F8d5uHBDQ==", "dev": true, "license": "MIT", "dependencies": { - "@babel/types": "^7.26.9" + "@babel/types": "^7.28.5" }, "bin": { "parser": "bin/babel-parser.js" @@ -2114,35 +2115,25 @@ } }, "node_modules/@babel/runtime": { - "version": "7.25.6", - "resolved": "https://registry.npmjs.org/@babel/runtime/-/runtime-7.25.6.tgz", - 
"integrity": "sha512-VBj9MYyDb9tuLq7yzqjgzt6Q+IBQLrGZfdjOekyEirZPHxXWoTSGUTMrpsfi58Up73d13NfYLv8HT9vmznjzhQ==", + "version": "7.28.4", + "resolved": "https://registry.npmjs.org/@babel/runtime/-/runtime-7.28.4.tgz", + "integrity": "sha512-Q/N6JNWvIvPnLDvjlE1OUBLPQHH6l3CltCEsHIujp45zQUSSh8K+gHnaEX45yAT1nyngnINhvWtzN+Nb9D8RAQ==", "dev": true, "license": "MIT", - "dependencies": { - "regenerator-runtime": "^0.14.0" - }, "engines": { "node": ">=6.9.0" } }, - "node_modules/@babel/runtime/node_modules/regenerator-runtime": { - "version": "0.14.1", - "resolved": "https://registry.npmjs.org/regenerator-runtime/-/regenerator-runtime-0.14.1.tgz", - "integrity": "sha512-dYnhHh0nJoMfnkZs6GmmhFknAGRrLznOu5nc9ML+EJxGvrx6H7teuevqVqCuPcPK//3eDrrjQhehXVx9cnkGdw==", - "dev": true, - "license": "MIT" - }, "node_modules/@babel/template": { - "version": "7.26.9", - "resolved": "https://registry.npmjs.org/@babel/template/-/template-7.26.9.tgz", - "integrity": "sha512-qyRplbeIpNZhmzOysF/wFMuP9sctmh2cFzRAZOn1YapxBsE1i9bJIY586R/WBLfLcmcBlM8ROBiQURnnNy+zfA==", + "version": "7.27.2", + "resolved": "https://registry.npmjs.org/@babel/template/-/template-7.27.2.tgz", + "integrity": "sha512-LPDZ85aEJyYSd18/DkjNh4/y1ntkE5KwUHWTiqgRxruuZL2F1yuHligVHLvcHY2vMHXttKFpJn6LwfI7cw7ODw==", "dev": true, "license": "MIT", "dependencies": { - "@babel/code-frame": "^7.26.2", - "@babel/parser": "^7.26.9", - "@babel/types": "^7.26.9" + "@babel/code-frame": "^7.27.1", + "@babel/parser": "^7.27.2", + "@babel/types": "^7.27.1" }, "engines": { "node": ">=6.9.0" @@ -2189,14 +2180,14 @@ "license": "MIT" }, "node_modules/@babel/types": { - "version": "7.26.9", - "resolved": "https://registry.npmjs.org/@babel/types/-/types-7.26.9.tgz", - "integrity": "sha512-Y3IR1cRnOxOCDvMmNiym7XpXQ93iGDDPHx+Zj+NM+rg0fBaShfQLkg+hKPaZCEvg5N/LeCo4+Rj/i3FuJsIQaw==", + "version": "7.28.5", + "resolved": "https://registry.npmjs.org/@babel/types/-/types-7.28.5.tgz", + "integrity": "sha512-qQ5m48eI/MFLQ5PxQj4PFaprjyCTLI37ElWMmNs0K8Lk3dVeOdNpB3ks8jc7yM5CDmVC73eMVk/trk3fgmrUpA==", "dev": true, "license": "MIT", "dependencies": { - "@babel/helper-string-parser": "^7.25.9", - "@babel/helper-validator-identifier": "^7.25.9" + "@babel/helper-string-parser": "^7.27.1", + "@babel/helper-validator-identifier": "^7.28.5" }, "engines": { "node": ">=6.9.0" @@ -3319,9 +3310,9 @@ } }, "node_modules/babel-plugin-module-resolver/node_modules/brace-expansion": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.1.tgz", - "integrity": "sha512-XnAIvQ8eM+kC6aULx6wuQiwVsnzsi9d3WxzV3FpWTGA19F621kwdbsAcFKXgKUHZWsy+mY6iL1sHTxWEFCytDA==", + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.2.tgz", + "integrity": "sha512-Jt0vHyM+jmUBqojB7E1NIYadt0vI0Qxjxd2TErW94wDz+E2LAm5vKMXXwg6ZZBTHPuUlDgQHKXvjGBdfcF1ZDQ==", "dev": true, "license": "MIT", "dependencies": { @@ -3477,7 +3468,9 @@ } }, "node_modules/brace-expansion": { - "version": "1.1.11", + "version": "1.1.12", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.12.tgz", + "integrity": "sha512-9T9UjW3r0UW5c1Q7GTwllptXwhvYmEzFhzMfZ9H7FQWt+uZePjZPjBP/W1ZEyZ1twGWom5/56TF4lPcqjnDHcg==", "dev": true, "license": "MIT", "dependencies": { @@ -3831,9 +3824,9 @@ } }, "node_modules/compression": { - "version": "1.8.0", - "resolved": "https://registry.npmjs.org/compression/-/compression-1.8.0.tgz", - "integrity": "sha512-k6WLKfunuqCYD3t6AsuPGvQWaKwuLLh2/xHNcX4qE+vIfDNXpSqnrhwA7O53R7WVQUnt8dVAIW+YHr7xTgOgGA==", + "version": 
"1.8.1", + "resolved": "https://registry.npmjs.org/compression/-/compression-1.8.1.tgz", + "integrity": "sha512-9mAqGPHLakhCLeNyxPkK4xVo746zQ/czLH1Ky+vkitMnWfWZps8r0qXuwhwizagCRttsL4lfG4pIOvaWLpAP0w==", "dev": true, "license": "MIT", "dependencies": { @@ -3841,7 +3834,7 @@ "compressible": "~2.0.18", "debug": "2.6.9", "negotiator": "~0.6.4", - "on-headers": "~1.0.2", + "on-headers": "~1.1.0", "safe-buffer": "5.2.1", "vary": "~1.1.2" }, @@ -4821,9 +4814,9 @@ } }, "node_modules/image-size": { - "version": "1.2.0", - "resolved": "https://registry.npmjs.org/image-size/-/image-size-1.2.0.tgz", - "integrity": "sha512-4S8fwbO6w3GeCVN6OPtA9I5IGKkcDMPcKndtUlpJuCwu7JLjtj7JZpwqLuyY2nrmQT3AWsCJLSKPsc2mPBSl3w==", + "version": "1.2.1", + "resolved": "https://registry.npmjs.org/image-size/-/image-size-1.2.1.tgz", + "integrity": "sha512-rH+46sQJ2dlwfjfhCyNx5thzrv+dtmBIhPHk0zgRUukHzZ/kRueTJXoYYsclBaKcSMBWuGbOFXtioLpzTb5euw==", "dev": true, "license": "MIT", "dependencies": { @@ -5250,7 +5243,9 @@ "license": "MIT" }, "node_modules/js-yaml": { - "version": "3.14.1", + "version": "3.14.2", + "resolved": "https://registry.npmjs.org/js-yaml/-/js-yaml-3.14.2.tgz", + "integrity": "sha512-PMSmkqxr106Xa156c2M265Z+FTrPl+oxd/rgOQy2tijQeK5TxQ43psO1ZCwhVOSdnn+RzkzlRz/eY4BgJBYVpg==", "dev": true, "license": "MIT", "dependencies": { @@ -6544,9 +6539,9 @@ } }, "node_modules/on-headers": { - "version": "1.0.2", - "resolved": "https://registry.npmjs.org/on-headers/-/on-headers-1.0.2.tgz", - "integrity": "sha512-pZAE+FJLoyITytdqK0U5s+FIpjN0JP3OzFi/u8Rx+EV5/W+JTWGXG8xFzevE7AjBfDqHv/8vL8qQsIhHnqRkrA==", + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/on-headers/-/on-headers-1.1.0.tgz", + "integrity": "sha512-737ZY3yNnXy37FHkQxPzt4UZ2UWPWiCZWLvFZ4fu5cueciegX0zGPnrlY6bwRg4FdQOe9YU8MkmJwGhoMybl8A==", "dev": true, "license": "MIT", "engines": { @@ -7130,9 +7125,9 @@ "license": "Python-2.0" }, "node_modules/react-native-builder-bob/node_modules/brace-expansion": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.1.tgz", - "integrity": "sha512-XnAIvQ8eM+kC6aULx6wuQiwVsnzsi9d3WxzV3FpWTGA19F621kwdbsAcFKXgKUHZWsy+mY6iL1sHTxWEFCytDA==", + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.2.tgz", + "integrity": "sha512-Jt0vHyM+jmUBqojB7E1NIYadt0vI0Qxjxd2TErW94wDz+E2LAm5vKMXXwg6ZZBTHPuUlDgQHKXvjGBdfcF1ZDQ==", "dev": true, "license": "MIT", "dependencies": { @@ -7203,9 +7198,9 @@ } }, "node_modules/react-native-builder-bob/node_modules/js-yaml": { - "version": "4.1.0", - "resolved": "https://registry.npmjs.org/js-yaml/-/js-yaml-4.1.0.tgz", - "integrity": "sha512-wpxZs9NoxZaJESJGIZTyDEaYpl0FKSA+FB9aJiyemKhMwkxQg63h4T1KJgUGHpTqPDNRcmmYLugrRjJlBtWvRA==", + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/js-yaml/-/js-yaml-4.1.1.tgz", + "integrity": "sha512-qQKT4zQxXl8lLwBtHMWwaTcGfFOZviOJet3Oy/xmGk2gZH677CJM9EvtfdSkgWcATZhj/55JZ0rmy3myCT5lsA==", "dev": true, "license": "MIT", "dependencies": { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts index 6a8dffb73fa08..f0f7527f665b9 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts @@ -360,7 +360,7 @@ const createInPlaceSoftmaxProgramInfo = ( let local_offset = local_idx * uniforms.elements_per_thread; let offset = (global_idx / ${WG}) * uniforms.total_sequence_length + local_offset; let seq_causal_length = ${seqLens ? 
'u32(past_sequence_length + workgroup_id.y + 1)' : 'total_sequence_length'};
-  var thread_max_vector = ${f32Type}(-3.402823e+38f);
+  var thread_max_vector = ${f32Type}(-3.4028234663852886e+38f);
   for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {
     thread_max_vector = max(${f32Type}(x[offset + i]), thread_max_vector);
   }
@@ -378,7 +378,7 @@
       })()};
   workgroupBarrier();
 
-  var max_value = f32(-3.402823e+38f);
+  var max_value = f32(-3.4028234663852886e+38f);
   for (var i = 0u; i < ${WG}; i++) {
     max_value = max(thread_max[i], max_value);
   }
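For orientation: the shader patched above computes an in-place softmax over each attention row. Every thread scans its slice of the row keeping a running maximum (seeded with the lowest finite f32 rather than negative infinity), the per-thread maxima are reduced across the workgroup, and the row is then exponentiated and normalized against that maximum. A scalar TypeScript sketch of the same max-subtract-exp-normalize scheme, for illustration only (names and the single-threaded layout are mine, not the shader's):

```typescript
// Numerically stable softmax over one row, mirroring the shader's structure:
// find the row maximum, subtract it before exponentiating, then normalize.
const LOWEST_F32 = -3.4028234663852886e38; // same seed the WGSL code uses

function softmaxRow(row: Float32Array): Float32Array {
  // Pass 1: row maximum (the shader splits this across threads, then
  // reduces the per-thread maxima in workgroup memory).
  let max = LOWEST_F32;
  for (const v of row) max = Math.max(max, v);

  // Pass 2: exponentiate relative to the maximum and accumulate the sum.
  const out = new Float32Array(row.length);
  let sum = 0;
  for (let i = 0; i < row.length; i++) {
    out[i] = Math.exp(row[i] - max);
    sum += out[i];
  }

  // Pass 3: normalize so the row sums to 1.
  for (let i = 0; i < row.length; i++) out[i] /= sum;
  return out;
}
```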
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts b/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts
index 2056416873df5..f6882280e91df 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts
@@ -81,7 +81,7 @@ const createSoftmaxProgramInfo = (context: ComputeContext, attributes: SoftmaxAt
   // 6.2.4 in wgsl spec
   const threadMaxDecl = tensorTypeToWsglStorageType(transposedInput.dataType) === 'f32'
-    ? `var threadMax = ${valueType}(-3.402823e+38f);`
+    ? `var threadMax = ${valueType}(-3.4028234663852886e+38f);`
     : `var threadMax = ${valueType}(-65504.0h);`;
   const getShaderSource = (shaderHelper: ShaderHelper) => `
   var rowMaxShared : ${valueType};
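All three replaced literals seed a running-maximum reduction, so they should be the lowest finite f32. The short literal -3.402823e+38 rounds to a float32 strictly greater than that minimum, while -3.4028234663852886e+38 is the exact value -(2 - 2^-23) * 2^127. A quick sketch of the arithmetic, runnable as plain TypeScript under Node (not code from this repo):

```typescript
// The exact most negative finite float32.
const LOWEST_F32 = -(2 - 2 ** -23) * 2 ** 127;

console.log(LOWEST_F32 === -3.4028234663852886e38);     // true: new literal is exact
console.log(Math.fround(-3.402823e38) === LOWEST_F32);  // false: old literal misses it
console.log(Math.fround(-3.402823e38) > LOWEST_F32);    // true: old seed sat above the minimum

// The f16 branch in softmax.ts already uses the exact lowest finite half.
const LOWEST_F16 = -(2 - 2 ** -10) * 2 ** 15;
console.log(LOWEST_F16 === -65504);                     // true
```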
diff --git a/onnxruntime/contrib_ops/rocm/bert/attention.cu b/onnxruntime/contrib_ops/rocm/bert/attention.cu
deleted file mode 100644
index b40fc2bf0eef8..0000000000000
--- a/onnxruntime/contrib_ops/rocm/bert/attention.cu
+++ /dev/null
@@ -1,215 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
-
-#include "contrib_ops/rocm/bert/attention.h"
-#include "contrib_ops/rocm/bert/attention_impl.h"
-#include "contrib_ops/rocm/bert/batched_gemm_permute_pipelines.cuh"
-#include "contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh"
-#include "contrib_ops/rocm/bert/transformer_common.h"
-#include "core/providers/rocm/rocm_common.h"
-#include "core/providers/rocm/shared_inc/fpgeneric.h"
-#include "core/providers/rocm/tunable/gemm.h"
-
-using namespace onnxruntime::rocm;
-using namespace ::onnxruntime::common;
-using namespace ONNX_NAMESPACE;
-
-namespace onnxruntime {
-namespace contrib {
-namespace rocm {
-
-constexpr int kPastSequenceLengthInputIndex = 6;
-constexpr int kPastInputIndex = 4;
-constexpr int kPresentOutputIndex = 1;
-
-#define REGISTER_KERNEL_TYPED(T)                                                  \
-  ONNX_OPERATOR_TYPED_KERNEL_EX(                                                  \
-      Attention,                                                                  \
-      kMSDomain,                                                                  \
-      1,                                                                          \
-      T,                                                                          \
-      kRocmExecutionProvider,                                                     \
-      (*KernelDefBuilder::Create())                                               \
-          .MayInplace(kPastInputIndex, kPresentOutputIndex)                       \
-          .TypeConstraint("T", DataTypeImpl::GetTensorType<T>())                  \
-          .InputMemoryType(OrtMemTypeCPUInput, kPastSequenceLengthInputIndex),    \
-      Attention<T>);
-
-REGISTER_KERNEL_TYPED(float)
-REGISTER_KERNEL_TYPED(MLFloat16)
-
-template <typename T>
-Attention<T>::Attention(const OpKernelInfo& info)
-    : RocmKernel(info), AttentionBase(info, true), attn_type_(kAttention) {
-  using HipT = typename ToHipType<T>::MappedType;
-  using AttentionTunableOp = GemmSoftmaxGemmPermuteTunableOp<HipT>;
-  tunable_op_ = std::make_shared<AttentionTunableOp>();
-}
-
-template <typename T>
-Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
-  const Tensor* input = context->Input<Tensor>(0);
-  const Tensor* weights = context->Input<Tensor>(1);
-  const Tensor* bias = context->Input<Tensor>(2);
-  const Tensor* mask_index = context->Input<Tensor>(3);
-  const Tensor* past = context->Input<Tensor>(4);
-  const Tensor* attention_bias = context->Input<Tensor>(5);
-  const Tensor* past_seq_len = context->Input<Tensor>(kPastSequenceLengthInputIndex);
-
-  auto& device_prop = GetDeviceProp();
-  RocmAttentionParameters attn;
-  ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(),
-                                  weights->Shape(),
-                                  bias->Shape(),
-                                  mask_index,
-                                  past,
-                                  attention_bias,
-                                  &attn,
-                                  device_prop.maxThreadsPerBlock,
-                                  past_seq_len));
-  ORT_ENFORCE(attn.sequence_length == attn.kv_sequence_length);  // self attention
-  ORT_ENFORCE(attn.qkv_format == Q_K_V_BNSH);                    // non-packed, permuted
-
-  TensorShapeVector output_shape(3);
-  output_shape[0] = static_cast<int64_t>(attn.batch_size);
-  output_shape[1] = static_cast<int64_t>(attn.sequence_length);
-  output_shape[2] = static_cast<int64_t>(attn.v_hidden_size);
-  Tensor* output = context->Output(0, output_shape);
-
-  std::vector<int64_t> present_dims{
-      2, attn.batch_size, attn.num_heads,
-      past_present_share_buffer_ ? attn.max_sequence_length : attn.total_sequence_length,
-      attn.head_size};
-  TensorShape present_shape(present_dims);
-  Tensor* present = context->Output(kPresentOutputIndex, present_shape);
-
-  auto stream = Stream(context);
-  hipblasHandle_t hipblas = GetHipblasHandle(context);
-
-  using HipT = typename ToHipType<T>::MappedType;
-  using QkvProjectGeneric = GemmPermuteGenericPipeline<HipT>;
-  using AttentionGeneric = GemmSoftmaxGemmPermuteGenericPipeline<HipT>;
-  using AttentionTunableOp = GemmSoftmaxGemmPermuteTunableOp<HipT>;
-
-  ORT_RETURN_IF_ERROR(ClassifyAttentionMode(attn_type_, &attn, /*qkv=*/{}, /*past=*/{past}, /*present=*/{present}));
-  ORT_ENFORCE(attn.mode == QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE ||
-              attn.mode == QFMT_KFMT_VFMT_NONE_NONE_2BNTH_NONE ||
-              attn.mode == QFMT_KFMT_VFMT_NONE_NONE_2BNMH_NONE ||
-              attn.mode == QFMT_KFMT_VFMT_2BNPH_NONE_2BNTH_NONE ||
-              attn.mode == QFMT_KFMT_VFMT_2BNMH_NONE_2BNMH_NONE);
-
-  size_t qkv_project_output_bytes = QkvProjectGeneric::GetOutputNumBytes(&attn);
-  size_t shared_workspace_bytes = std::max(QkvProjectGeneric::GetWorkspaceNumBytes(&attn),
-                                           AttentionGeneric::GetWorkspaceNumBytes(&attn));
-  if (GetTuningContext()->IsTunableOpEnabled()) {
-    shared_workspace_bytes = std::max(shared_workspace_bytes, AttentionTunableOp::GetWorkspaceNumBytes(&attn));
-  }
-
-  auto qkv_project_output = GetScratchBuffer<void>(qkv_project_output_bytes, context->GetComputeStream());
-  auto workspace = GetScratchBuffer<void>(shared_workspace_bytes, context->GetComputeStream());
-
-  GemmPermuteParams<HipT> gemm_permute_params;
-  {
-    auto& params = gemm_permute_params;
-    params.tuning_ctx = GetTuningContext();
-    params.stream = context->GetComputeStream();
-    params.handle = hipblas;
-    params.attention = &attn;
-    params.device_prop = &device_prop;
-
-    params.input_buffer = reinterpret_cast<const HipT*>(input->DataRaw());
-    params.weight_buffer = reinterpret_cast<const HipT*>(weights->DataRaw());
-    params.bias_buffer = reinterpret_cast<const HipT*>(bias->DataRaw());
-    params.out_buffer = reinterpret_cast<HipT*>(qkv_project_output.get());
-    params.ones = GetConstOnes<HipT>(attn.batch_size * attn.sequence_length, stream);
-    params.workspace_buffer = reinterpret_cast<HipT*>(workspace.get());
-  }
-
-  ORT_RETURN_IF_ERROR(QkvProjectGeneric::Run(&gemm_permute_params));
-  auto [q_buffer, k_buffer, v_buffer] = QkvProjectGeneric::UnspliceOutputQKV(&gemm_permute_params);
-
-  // NOTE: GemmPermute always outputs 3BNSH, k_buffer and v_buffer can be treated as 2BNSH
-  if (nullptr != present) {
-    Strides dst_strides;  // the output buffer is present Tensor, the buffer is the same
-
-    int4 add_shape{2 * attn.batch_size, attn.num_heads, attn.sequence_length, attn.head_size};
-    HipT* add_dest = nullptr;              // destination of concatenated data to present
-    const HipT* const add_src = k_buffer;  // source of concatenated data to present
-    const auto add_src_strides = Strides::BNSHMemory(
-        2 * attn.batch_size, attn.num_heads, attn.sequence_length, attn.head_size);
-
-    if (attn.mode == QFMT_KFMT_VFMT_NONE_NONE_2BNTH_NONE) {
-      dst_strides = Strides::BNSHMemory(2 * attn.batch_size, attn.num_heads, attn.total_sequence_length, attn.head_size);
-      add_dest = reinterpret_cast<HipT*>(present->MutableDataRaw()) /* + dst_strides.OffsetAt(0, 0, 0, 0)*/;
-    } else if (attn.mode == QFMT_KFMT_VFMT_2BNPH_NONE_2BNTH_NONE) {
-      dst_strides = Strides::BNSHMemory(2 * attn.batch_size, attn.num_heads, attn.total_sequence_length, attn.head_size);
-      add_dest = reinterpret_cast<HipT*>(present->MutableDataRaw()) + dst_strides.OffsetAt(0, 0, attn.past_sequence_length, 0);
-
-      // We only need to copy past to present in this case. All other cases will build the present incrementally
-      const int4 past_shape = {2 * attn.batch_size, attn.num_heads, attn.past_sequence_length, attn.head_size};
-      HipT* const past_dest = reinterpret_cast<HipT*>(present->MutableDataRaw());
-      const HipT* const past_src = reinterpret_cast<const HipT*>(past->DataRaw());
-      const Strides past_src_strides = Strides::BNSHMemory(
-          2 * attn.batch_size, attn.num_heads, attn.past_sequence_length, attn.head_size);
-
-      ORT_RETURN_IF_ERROR(LaunchStridedCopy(stream, past_src, past_shape, past_src_strides.ForBNSHCoord(),
-                                            past_dest, dst_strides.ForBNSHCoord(), device_prop.maxThreadsPerBlock));
-    } else if (attn.mode == QFMT_KFMT_VFMT_NONE_NONE_2BNMH_NONE) {
-      dst_strides = Strides::BNSHMemory(2 * attn.batch_size, attn.num_heads, attn.max_sequence_length, attn.head_size);
-      add_dest = reinterpret_cast<HipT*>(present->MutableDataRaw()) /* + dst_strides.OffsetAt(0, 0, 0, 0)*/;
-    } else if (attn.mode == QFMT_KFMT_VFMT_2BNMH_NONE_2BNMH_NONE) {
-      dst_strides = Strides::BNSHMemory(2 * attn.batch_size, attn.num_heads, attn.max_sequence_length, attn.head_size);
-      add_dest = reinterpret_cast<HipT*>(present->MutableDataRaw()) + dst_strides.OffsetAt(0, 0, attn.past_sequence_length, 0);
-    }
-
-    ORT_RETURN_IF_ERROR(LaunchStridedCopy(stream, add_src, add_shape, add_src_strides.ForBNSHCoord(),
-                                          add_dest, dst_strides.ForBNSHCoord(), device_prop.maxThreadsPerBlock));
-
-    // update pointers to present_k and present_v. TODO: switch to ConvertToOffsetedBufferViews
-    k_buffer = reinterpret_cast<HipT*>(present->MutableDataRaw());
-    v_buffer = reinterpret_cast<HipT*>(present->MutableDataRaw()) + dst_strides.OffsetAt(attn.batch_size, 0, 0, 0);
-  }
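The block just above is the KV-cache bookkeeping the deleted kernel performed: past keys/values are copied into the present tensor (only when the buffers are not shared), and the freshly projected K/V are appended at offset past_sequence_length. A flattened TypeScript sketch of that append, purely illustrative (the layout and every name here are mine, not the kernel's):

```typescript
// Append one step of new K (or V) data after the cached past entries,
// mirroring the strided copies above. Buffers are flattened
// [batch * heads, seq, headSize] row-major.
function appendToCache(
  past: Float32Array,     // [bn, pastLen, h]
  fresh: Float32Array,    // [bn, newLen, h]
  present: Float32Array,  // [bn, maxLen, h], maxLen >= pastLen + newLen
  bn: number, pastLen: number, newLen: number, maxLen: number, h: number,
): void {
  for (let b = 0; b < bn; b++) {
    // Copy past entries first (skipped when past and present share a buffer).
    present.set(past.subarray(b * pastLen * h, (b + 1) * pastLen * h), b * maxLen * h);
    // Write the new entries at offset pastLen, as the kernel did with
    // dst_strides.OffsetAt(0, 0, past_sequence_length, 0).
    present.set(fresh.subarray(b * newLen * h, (b + 1) * newLen * h), (b * maxLen + pastLen) * h);
  }
}
```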
-
-  // For testing, environment variable ORT_TRANSFORMER_OPTIONS=1 could enable persistent softmax
-  const TransformerOptions* options = TransformerOptions::GetInstance();
-  bool use_persistent_softmax = options->IsPrecisionMode() && !options->DisablePersistentSoftmax();
-
-  GemmSoftmaxGemmPermuteParams<HipT> gemm_softmax_gemm_permute_params;
-  {
-    auto& params = gemm_softmax_gemm_permute_params;
-    params.tuning_ctx = GetTuningContext();
-    params.stream = context->GetComputeStream();
-    params.handle = hipblas;
-    params.attention = &attn;
-    params.device_prop = &device_prop;
-    // FIXME: the params.scale seems to be different from AttentionParameters::scale;
-    params.scale = 1.0f / sqrt(static_cast<float>(attn.head_size));
-    // TODO: switch to ConvertToOffsetedBufferViews
-    params.q_buffer = q_buffer;
-    params.k_buffer = k_buffer;
-    params.v_buffer = v_buffer;
-    params.out_buffer = reinterpret_cast<HipT*>(output->MutableDataRaw());
-
-    if (attention_bias != nullptr) {
-      params.bias_buffer = reinterpret_cast<const HipT*>(attention_bias->DataRaw());
-    }
-
-    if (mask_index != nullptr) {
-      params.mask_index_buffer = mask_index->Data<int>();
-      params.mask_index_dims = mask_index->Shape().AsShapeVector();
-    }
-
-    params.workspace_buffer = reinterpret_cast<HipT*>(workspace.get());
-  }
-
-  if (this->GetTuningContext()->IsTunableOpEnabled() &&
-      !use_persistent_softmax) {
-    return (*std::static_pointer_cast<AttentionTunableOp>(tunable_op_))(&gemm_softmax_gemm_permute_params);
-  } else {
-    return AttentionGeneric::Run(&gemm_softmax_gemm_permute_params, use_persistent_softmax);
-  }
-}
-
-}  // namespace rocm
-}  // namespace contrib
-}  // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/rocm/bert/attention.h b/onnxruntime/contrib_ops/rocm/bert/attention.h
deleted file mode 100644
index 7204fd660a516..0000000000000
--- a/onnxruntime/contrib_ops/rocm/bert/attention.h
+++ /dev/null
@@ -1,33 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
-
-#pragma once
-
-#include "core/providers/rocm/rocm_kernel.h"
-#include "contrib_ops/cpu/bert/attention_base.h"
-#include "contrib_ops/rocm/bert/attention_impl.h"
-
-namespace onnxruntime {
-namespace contrib {
-namespace rocm {
-
-using namespace onnxruntime::rocm;
-
-template <typename T>
-class Attention final : public RocmKernel, public AttentionBase {
- public:
-  Attention(const OpKernelInfo& info);
-  Status ComputeInternal(OpKernelContext* context) const override;
-
- public:
-  AttentionType attn_type_;
-
-  // type-erased GemmSoftmaxGemmPermuteTunableOp, the reason for this is:
-  // 1. We don't want to include the cuh file where GemmSoftmaxGemmPermuteTunableOp is defined.
-  // 2. We don't want to construct the object repeatedly (which is expensive) during Compute.
-  std::shared_ptr<void> tunable_op_;
-};
-
-}  // namespace rocm
-}  // namespace contrib
-}  // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu b/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu
deleted file mode 100644
index 270a8e51daf88..0000000000000
--- a/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu
+++ /dev/null
@@ -1,435 +0,0 @@
-/*
- The implementation of this file is based on qkvToContext plugin in TensorRT demo:
- https://github.com/NVIDIA/TensorRT/tree/release/5.1/demo/BERT/
-
-Copyright 2019 NVIDIA Corporation
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-*/
-
-// Modifications: scaling is moved from masked softmax to the gemm before that.
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
-
-#include
-#include "core/providers/rocm/cu_inc/common.cuh"
-#include "core/providers/rocm/rocm_common.h"
-#include "core/providers/rocm/shared_inc/fpgeneric.h"
-#include "core/providers/rocm/tunable/gemm.h"
-#include "core/providers/rocm/tunable/rocm_tunable.h"
-#include "contrib_ops/cpu/bert/attention_base.h"
-#include "contrib_ops/rocm/bert/attention_impl.h"
-#include "contrib_ops/rocm/bert/attention_softmax.h"
-#include "contrib_ops/rocm/bert/decoder_attention_impl.h"
-
-using namespace onnxruntime::rocm;
-
-namespace blas = onnxruntime::rocm::tunable::blas;
-
-#define CHECK_ROCM(expr) HIP_RETURN_IF_ERROR(expr)
-
-using namespace onnxruntime::rocm;
-using namespace ::onnxruntime::common;
-using namespace ONNX_NAMESPACE;
-
-namespace onnxruntime {
-namespace contrib {
-namespace rocm {
-
-static size_t AlignTo(size_t a, size_t b) {
-  return CeilDiv(a, b) * b;
-}
-
-size_t GetAttentionScratchSize(size_t element_size,
-                               int batch_size,
-                               int num_heads,
-                               int sequence_length,
-                               int total_sequence_length) {
-  const size_t bytes = element_size * batch_size * num_heads * sequence_length * total_sequence_length;
-
-  const size_t alignment = 256;
-  const size_t bytesAligned = AlignTo(bytes, alignment);
-  return bytesAligned;
-}
-
-size_t GetAttentionWorkspaceSize(
-    size_t element_size,
-    int batch_size,
-    int num_heads,
-    int head_size,
-    int sequence_length,
-    int total_sequence_length) {
-  size_t qkv_size = element_size * 3 * batch_size * sequence_length * num_heads * head_size;
-  return qkv_size + 2 * GetAttentionScratchSize(element_size, batch_size, num_heads,
-                                                sequence_length, total_sequence_length);
-}
-
-inline int3 Get2DMaskStrides(int total_sequence_length) {
-  // stride == 0 indicates broadcasting
-  return {total_sequence_length, 0, 1};
-}
-
-Status ClassifyAttentionMode(
-    AttentionType attn_type,
-    RocmAttentionParameters* attn,
-    const std::vector<const Tensor*>& qkv,
-    const std::vector<const Tensor*>& past,
-    const std::vector<const Tensor*>& present) {
-  size_t num_qkv = std::count_if(qkv.cbegin(), qkv.cend(), [](auto it) { return it != nullptr; });
-  size_t num_past = std::count_if(past.cbegin(), past.cend(), [](auto it) { return it != nullptr; });
-  size_t num_present = std::count_if(present.cbegin(), present.cend(), [](auto it) { return it != nullptr; });
}); - - auto hint = MakeString(num_qkv, " qkv inputs, ", num_past, " past inputs and ", num_present, " present inputs"); - LOGS_DEFAULT(VERBOSE) << hint; - - if (attn_type == kAttention) { - ORT_ENFORCE(num_qkv == 0); - if (num_past == 0 && num_present == 0) { - attn->mode = QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE; - return Status::OK(); - } else if (num_past == 0 && num_present == 1) { - if (attn->past_present_share_buffer == false) { - attn->mode = QFMT_KFMT_VFMT_NONE_NONE_2BNTH_NONE; - return Status::OK(); - } else { - attn->mode = QFMT_KFMT_VFMT_NONE_NONE_2BNMH_NONE; - return Status::OK(); - } - } else if (num_past == 1 && num_present == 1) { - if (attn->past_present_share_buffer == false) { - attn->mode = QFMT_KFMT_VFMT_2BNPH_NONE_2BNTH_NONE; - return Status::OK(); - } else { - attn->mode = QFMT_KFMT_VFMT_2BNMH_NONE_2BNMH_NONE; - return Status::OK(); - } - } - } else if (attn_type == kMultiHeadAttention || attn_type == kDecoderMaskedMultiHeadAttention) { - if (num_qkv == 3 && num_past == 0 && num_present == 0) { - if (attn->qkv_format == Q_K_V_BSNH) { - attn->mode = BSNH_BLNH_BLNH_NONE_NONE_NONE_NONE; - return Status::OK(); - } else if (attn->qkv_format == Q_K_V_BSNH_BNSH_BNSH) { - attn->mode = BSNH_BNLH_BNLH_NONE_NONE_NONE_NONE; - return Status::OK(); - } - } else if (num_qkv == 3 && num_past == 0 && num_present == 2) { - if (attn->past_present_share_buffer == false) { - if (attn->qkv_format == Q_K_V_BSNH) { - attn->mode = BSNH_BLNH_BLNH_NONE_NONE_BNTH_BNTH; - return Status::OK(); - } else if (attn->qkv_format == Q_K_V_BSNH_BNSH_BNSH) { - attn->mode = BSNH_BNLH_BNLH_NONE_NONE_BNTH_BNTH; - return Status::OK(); - } - } else { - if (attn->qkv_format == Q_K_V_BSNH) { - attn->mode = BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH; - return Status::OK(); - } else if (attn->qkv_format == Q_K_V_BSNH_BNSH_BNSH) { - attn->mode = BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH; - return Status::OK(); - } - } - } else if (num_qkv == 3 && num_past == 2 && num_present == 2) { - if (attn->past_present_share_buffer == false) { - if (attn->qkv_format == Q_K_V_BSNH) { - attn->mode = BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH; - return Status::OK(); - } else if (attn->qkv_format == Q_K_V_BSNH_BNSH_BNSH) { - attn->mode = BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH; - return Status::OK(); - } - } else { - if (attn->qkv_format == Q_K_V_BSNH) { - attn->mode = BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH; - return Status::OK(); - } else if (attn->qkv_format == Q_K_V_BSNH_BNSH_BNSH) { - attn->mode = BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH; - return Status::OK(); - } - } - } else if (num_qkv == 1 && num_past == 0 && num_present == 0) { - if (attn->qkv_format == QKV_BSN3H) { - attn->mode = BLN3H_NONE_NONE_NONE_NONE_NONE_NONE; - return Status::OK(); - } - } else if (num_qkv == 2 && num_past == 0 && num_present == 0) { - if (attn->qkv_format == Q_KV_BSNH_BSN2H) { - attn->mode = BSNH_BLN2H_NONE_NONE_NONE_NONE_NONE; - return Status::OK(); - } - } - } - return ORT_MAKE_STATUS( - ONNXRUNTIME, INVALID_ARGUMENT, - "Unsupported AttentionMode for ", attn_type, ". Got qkv format ", attn->qkv_format, - ". 
Got ", hint); -} - -template -Status DecoderQkvToContext( - const hipDeviceProp_t& prop, - RocmTuningContext* tuning_ctx, - Stream* ort_stream, - hipblasHandle_t& hipblas, - const size_t element_size, - const int batch_size, - const int sequence_length, - const int kv_sequence_length, - const int num_heads, - const int head_size, - const bool static_kv, - const bool use_past, - const bool has_layer_state, - const bool has_key_padding_mask, - const float mask_filter_value, - const T* gemm_query_buffer, - const T* gemm_kv_buffer, - const bool* key_padding_mask, - const T* key_cache, - const T* value_cache, - T* qkv_buffer, - T* workspace_buffer, - T* output, - T* new_key_cache, - T* new_value_cache) { - const int max_threads_per_block = prop.maxThreadsPerBlock; - const int BN = batch_size * num_heads; - const int BHN = BN * head_size; - const int BNS = BN * sequence_length; - const int k_buffer_offset = sequence_length * BHN; - const int v_buffer_offset = (sequence_length + kv_sequence_length) * BHN; - - T* temp_qkv_buffer = workspace_buffer; - auto stream = static_cast(ort_stream->GetHandle()); - - const T* q = qkv_buffer; - // transpose q and copy them to qkv_buffer - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, head_size, - num_heads, max_threads_per_block, true, gemm_query_buffer, qkv_buffer)); - - const T* k = qkv_buffer + k_buffer_offset; - const T* v = qkv_buffer + v_buffer_offset; - if (!has_layer_state || !use_past) { - if (!static_kv) { - // transpose kv and copy them to qkv_buffer - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 2, sequence_length, batch_size, head_size, num_heads, - max_threads_per_block, true, gemm_kv_buffer, qkv_buffer + k_buffer_offset)); - } else { - // transpose kv and copy them to qkv_buffer - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 2, kv_sequence_length, batch_size, head_size, num_heads, - max_threads_per_block, true, gemm_kv_buffer, qkv_buffer + k_buffer_offset)); - } - } else { - if (!static_kv) { - // transpose kv and copy them to temp_buffer - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 2, sequence_length, batch_size, head_size, num_heads, - max_threads_per_block, true, gemm_kv_buffer, temp_qkv_buffer)); - // concat cache-k with k and copy to qkv_buffer - if (nullptr != key_cache) { - ORT_RETURN_IF_ERROR(LaunchConcatTensorToTensor(stream, kv_sequence_length, sequence_length, - batch_size, head_size, num_heads, - max_threads_per_block, 1, key_cache, - temp_qkv_buffer, qkv_buffer + k_buffer_offset)); - } - // concat cache-v with v and copy to qkv_buffer - if (nullptr != value_cache) { - ORT_RETURN_IF_ERROR(LaunchConcatTensorToTensor(stream, kv_sequence_length, sequence_length, - batch_size, head_size, num_heads, - max_threads_per_block, 1, value_cache, - temp_qkv_buffer + k_buffer_offset, - qkv_buffer + v_buffer_offset)); - } - } - } - - if (has_layer_state) { - if (use_past && static_kv) { - CHECK_ROCM(hipMemcpyAsync(new_key_cache, key_cache, - kv_sequence_length * BHN * sizeof(T), hipMemcpyDeviceToDevice, stream)); - CHECK_ROCM(hipMemcpyAsync(new_value_cache, value_cache, - kv_sequence_length * BHN * sizeof(T), hipMemcpyDeviceToDevice, stream)); - } else { - CHECK_ROCM(hipMemcpyAsync(new_key_cache, k, - kv_sequence_length * BHN * sizeof(T), hipMemcpyDeviceToDevice, stream)); - CHECK_ROCM(hipMemcpyAsync(new_value_cache, v, - kv_sequence_length * BHN * sizeof(T), hipMemcpyDeviceToDevice, stream)); - } - } - - // scratch1: BxNxSxS* buffer - // scratch2: BxNxSxS* buffer - // scratch3: BxNxSxH buffer - T* scratch1 = 
temp_qkv_buffer + 3 * BHN * sequence_length; - T* scratch2 = scratch1 + BNS * kv_sequence_length; - T* scratch3 = scratch2 + BNS * kv_sequence_length; - - // compute Q*K' (as K'*Q), scaled by 1/sqrt(H) and store in scratch1: BxNxSxS* - // Q: BxNxSxH, K (present_k): BxNxS*xH, Q*K': BxNxSxS* - const float rsqrt_head_size = 1.f / sqrt(static_cast(head_size)); - const int temp_matrix_size = sequence_length * kv_sequence_length; - - const int strideA = kv_sequence_length * head_size; - const int strideB = sequence_length * head_size; - if (use_past && static_kv) { - ORT_RETURN_IF_ERROR(blas::column_major::StridedBatchedGemm( - tuning_ctx, ort_stream, hipblas, - blas::BlasOp::Trans, blas::BlasOp::NonTrans, - kv_sequence_length, sequence_length, head_size, - /*alpha=*/rsqrt_head_size, - key_cache, head_size, strideA, - q, head_size, strideB, - /*beta=*/0.0f, - scratch1, kv_sequence_length, temp_matrix_size, - BN)); - } else { - ORT_RETURN_IF_ERROR(blas::column_major::StridedBatchedGemm( - tuning_ctx, ort_stream, hipblas, - blas::BlasOp::Trans, blas::BlasOp::NonTrans, - kv_sequence_length, sequence_length, head_size, - /*alpha=*/rsqrt_head_size, - k, head_size, strideA, - q, head_size, strideB, - /*beta=*/0.0f, - scratch1, kv_sequence_length, temp_matrix_size, - BN)); - } - - if (has_key_padding_mask) { - int3 strides = Get2DMaskStrides(kv_sequence_length); - ORT_RETURN_IF_ERROR(ComputeSoftmaxWithRawMask( - ort_stream, kv_sequence_length, sequence_length, batch_size, num_heads, - strides, nullptr, key_padding_mask, nullptr, scratch1, scratch2, - false, 1.0f, false, nullptr, mask_filter_value)); - } else { - ORT_RETURN_IF_ERROR(ComputeSoftmax(stream, kv_sequence_length, sequence_length, batch_size, - num_heads, nullptr, scratch1, scratch2, false)); - } - - // compute P*V (as V*P), and store in scratch3: BxNxSxH - if (use_past && static_kv) { - ORT_RETURN_IF_ERROR(blas::column_major::StridedBatchedGemm( - tuning_ctx, ort_stream, hipblas, - blas::BlasOp::NonTrans, blas::BlasOp::NonTrans, - head_size, sequence_length, kv_sequence_length, - /*alpha=*/1.0f, - value_cache, head_size, strideA, - scratch2, kv_sequence_length, temp_matrix_size, - /*beta=*/0.0f, - scratch3, head_size, strideB, - BN)); - } else { - ORT_RETURN_IF_ERROR(blas::column_major::StridedBatchedGemm( - tuning_ctx, ort_stream, hipblas, - blas::BlasOp::NonTrans, blas::BlasOp::NonTrans, - head_size, sequence_length, kv_sequence_length, - /*alpha=*/1.0f, - v, head_size, strideA, - scratch2, kv_sequence_length, temp_matrix_size, - /*beta=*/0.0f, - scratch3, head_size, strideB, - BN)); - } - - // scratch3 is BxNxSxH, transpose to output SxBxNxH - return LaunchTransCtx(stream, sequence_length, batch_size, head_size, - num_heads, max_threads_per_block, true, scratch3, output); -} - -Status LaunchDecoderAttentionKernel( - const hipDeviceProp_t& prop, - RocmTuningContext* tuning_ctx, - Stream* stream, - hipblasHandle_t& hipblas, - const size_t element_size, - const int batch_size, - const int sequence_length, - const int kv_sequence_length, - const int num_heads, - const int head_size, - const bool static_kv, - const bool use_past, - const bool has_layer_state, - const bool has_key_padding_mask, - const float mask_filter_value, - const void* gemm_query_buffer, - const void* gemm_kv_buffer, - const bool* key_padding_mask, - const void* key_cache, - const void* value_cache, - void* qkv_buffer, - void* workspace_buffer, - void* output, - void* new_key_cache, - void* new_value_cache) { - if (element_size == 2) { - return DecoderQkvToContext( - prop, 
- tuning_ctx, - stream, - hipblas, - element_size, - batch_size, - sequence_length, - kv_sequence_length, - num_heads, - head_size, - static_kv, - use_past, - has_layer_state, - has_key_padding_mask, - mask_filter_value, - reinterpret_cast(gemm_query_buffer), - reinterpret_cast(gemm_kv_buffer), - key_padding_mask, - reinterpret_cast(key_cache), - reinterpret_cast(value_cache), - reinterpret_cast(qkv_buffer), - reinterpret_cast(workspace_buffer), - reinterpret_cast(output), - reinterpret_cast(new_key_cache), - reinterpret_cast(new_value_cache)); - } else { - return DecoderQkvToContext( - prop, - tuning_ctx, - stream, - hipblas, - element_size, - batch_size, - sequence_length, - kv_sequence_length, - num_heads, - head_size, - static_kv, - use_past, - has_layer_state, - has_key_padding_mask, - mask_filter_value, - reinterpret_cast(gemm_query_buffer), - reinterpret_cast(gemm_kv_buffer), - key_padding_mask, - reinterpret_cast(key_cache), - reinterpret_cast(value_cache), - reinterpret_cast(qkv_buffer), - reinterpret_cast(workspace_buffer), - reinterpret_cast(output), - reinterpret_cast(new_key_cache), - reinterpret_cast(new_value_cache)); - } -} - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/attention_impl.h b/onnxruntime/contrib_ops/rocm/bert/attention_impl.h deleted file mode 100644 index 07d875e90fa4b..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/attention_impl.h +++ /dev/null @@ -1,180 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include -#include "contrib_ops/cpu/bert/attention_common.h" -#include "contrib_ops/cpu/bert/attention_parameters.h" -#include "core/providers/rocm/shared_inc/rocm_utils.h" -#include "core/providers/rocm/tunable/rocm_tunable.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -typedef struct __align__(32) { - long long int x, y, z, w; -} LongLong4; - -size_t GetAttentionScratchSize( - size_t element_size, - int batch_size, - int num_heads, - int sequence_length, - int all_sequence_length); - -size_t GetAttentionWorkspaceSize( - size_t element_size, - int batch_size, - int num_heads, - int head_size, - int sequence_length, - int past_sequence_length); - -Status LaunchTransCtx(hipStream_t stream, - const int sequence_length, const int batch_size, const int head_size, const int num_heads, - const int max_threads_per_block, const bool reversed_bs, const float* input, float* output); - -Status LaunchTransCtx(hipStream_t stream, - const int sequence_length, const int batch_size, const int head_size, const int num_heads, - const int max_threads_per_block, const bool reversed_bs, const half* input, half* output); - -Status LaunchTransQkv(hipStream_t stream, const int matrix_num, - const int sequence_length, const int batch_size, const int head_size, const int num_heads, - const int max_threads_per_block, const bool reversed_bs, const float* input, float* output, - int total_matrix_count = -1); - -Status LaunchTransQkv(hipStream_t stream, const int matrix_num, - const int sequence_length, const int batch_size, const int head_size, const int num_heads, - const int max_threads_per_block, const bool reversed_bs, const half* input, half* output, - int total_matrix_count = -1); - -Status LaunchConcatTensorToTensor(hipStream_t stream, - const int all_sequence_length, - const int sequence_length, - const int batch_size, - const int head_size, - const int num_heads, - const int 
max_threads_per_block, - const int matrix_num, - const float* tensor_in, - const float* tensor_add, - float* tensor_out); - -Status LaunchConcatTensorToTensor(hipStream_t stream, - const int all_sequence_length, - const int sequence_length, - const int batch_size, - const int head_size, - const int num_heads, - const int max_threads_per_block, - const int matrix_num, - const half* tensor_in, - const half* tensor_add, - half* tensor_out); - -inline hipblasStatus_t _compat_hipblas_gemm_strided_batched_ex(hipblasHandle_t handle, - hipblasOperation_t transa, - hipblasOperation_t transb, - int m, - int n, - int k, - const void* alpha, - const void* A, - hipDataType a_type, - int lda, - hipblasStride stride_A, - const void* b, - hipDataType b_type, - int ldb, - hipblasStride stride_b, - const void* beta, - void* c, - hipDataType c_type, - int ldc, - hipblasStride stride_c, - int batch_count, - hipblasComputeType_t compute_type, - hipblasGemmAlgo_t algo) { - return hipblasGemmStridedBatchedEx(handle, - transa, - transb, - m, // m - n, // n - k, // k - alpha, // alpha - A, // A - a_type, // A type - lda, // lda - stride_A, // strideA - b, // B - b_type, // B type - ldb, // ldb - stride_b, // strideB - beta, // beta - c, // C - c_type, // C type - ldc, // ldc - stride_c, // strideC - batch_count, // batch count - compute_type, - algo); -} - -// Compatible for CublasMathModeSetter -class CompatHipblasMathModeSetter { - public: - CompatHipblasMathModeSetter(const hipDeviceProp_t&, - hipblasHandle_t, - int) { - } -}; - -enum AttentionMode { - // Q,K,V,PastK,PastV,PresentK,PresentV - QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE, - QFMT_KFMT_VFMT_NONE_NONE_2BNTH_NONE, - QFMT_KFMT_VFMT_NONE_NONE_2BNMH_NONE, - QFMT_KFMT_VFMT_2BNPH_NONE_2BNTH_NONE, - QFMT_KFMT_VFMT_2BNMH_NONE_2BNMH_NONE, - BSNH_BLNH_BLNH_NONE_NONE_NONE_NONE, - BSNH_BNLH_BNLH_NONE_NONE_NONE_NONE, - BSNH_BLNH_BLNH_NONE_NONE_BNTH_BNTH, - BSNH_BNLH_BNLH_NONE_NONE_BNTH_BNTH, - BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH, - BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH, - BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH, - BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH, - BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH, - BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH, - BLN3H_NONE_NONE_NONE_NONE_NONE_NONE, - BSNH_BLN2H_NONE_NONE_NONE_NONE_NONE, -}; - -struct RocmAttentionParameters : AttentionParameters { - AttentionMode mode; -}; - -Status ClassifyAttentionMode(AttentionType type, - RocmAttentionParameters* attn, - const std::vector& qkv, - const std::vector& past, - const std::vector& present); - -template -Status LaunchStridedCopy( - hipStream_t stream, - const T* in, int4 in_shape, LongLong4 in_strides, const int* in_seqlens_offset, // coord (b,n,s,h) - T* out, LongLong4 out_strides, const int* out_seqlens_offset, // coord (b,n,s,h) - int max_threads_per_block); - -template -Status LaunchStridedCopy(hipStream_t stream, - const T* in, int4 in_shape, LongLong4 in_strides, // coord (b,n,s,h) - T* out, LongLong4 out_strides, // coord (b,n,s,h) - int max_threads_per_block); -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/attention_softmax.h b/onnxruntime/contrib_ops/rocm/bert/attention_softmax.h deleted file mode 100644 index 9f2faa228cf79..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/attention_softmax.h +++ /dev/null @@ -1,465 +0,0 @@ -#include "hip/hip_runtime.h" -/* - The implementation of this file is based on qkvToContext plugin in TensorRT demo: - https://github.com/NVIDIA/TensorRT/tree/release/5.1/demo/BERT/ - -Copyright 2019 NVIDIA 
Corporation - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -#pragma once - -#include -#include -#include -#include "core/providers/rocm/cu_inc/common.cuh" -#include "core/providers/rocm/rocm_common.h" -#include "core/providers/rocm/math/softmax.h" - -#define ROCMRT_INF_F __int_as_float(0x7f800000) - -using namespace onnxruntime::rocm; -using namespace hipcub; - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -template -__device__ inline void Softmax(const int all_sequence_length, - const int valid_end, - const int valid_start, - const T* attn_bias, - const T* input, - T* output) { - using BlockReduce = hipcub::BlockReduce; - __shared__ typename BlockReduce::TempStorage tmp_storage; - - __shared__ float sum_reverse_block; - __shared__ float max_block; - - float thread_data_max(-ROCMRT_INF_F); - - // e^x is represented as infinity if x is large enough, like 100.f. - // Infinity divided by Infinity is a NAN. Thus, softmax gets a NAN if one or more item are large enough. - // a math transform as below is leveraged to get a stable softmax: - // e^xi/(e^x1 + ...e^xn) = e^(xi - max) / (e^(x1 - max) + ... + e^(xn - max)) - const int offset = (blockIdx.y * gridDim.x + blockIdx.x) * all_sequence_length; - for (int i = threadIdx.x; i < valid_end; i += TPB) { - if (i >= valid_start) { - const int index = offset + i; - float input_at_idx = attn_bias == nullptr - ? static_cast(input[index]) - : static_cast(input[index] + attn_bias[index]); - if (thread_data_max < input_at_idx) { - thread_data_max = input_at_idx; - } - } - } - - const auto max = BlockReduce(tmp_storage).Reduce(thread_data_max, hipcub::Max()); - - // Store max value - if (threadIdx.x == 0) { - max_block = max; - } - __syncthreads(); - - float thread_data_sum(0.f); - for (int i = threadIdx.x; i < valid_end; i += TPB) { - if (i >= valid_start) { - const int index = offset + i; - float val = attn_bias == nullptr ? input[index] : input[index] + attn_bias[index]; - thread_data_sum += expf(val - max_block); - } - } - - const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_sum, hipcub::Sum()); - if (threadIdx.x == 0) { - sum_reverse_block = 1.f / sum; - } - __syncthreads(); - - for (int i = threadIdx.x; i < all_sequence_length; i += TPB) { - const int index = offset + i; - float input_at_idx = attn_bias == nullptr - ? static_cast(input[index]) - : static_cast(input[index] + attn_bias[index]); - const float val = (i >= valid_start && i < valid_end) ? 
expf(input_at_idx - max_block) * sum_reverse_block : 0.f; - output[index] = T(val); - } -} - -template -__device__ inline void SoftmaxSmall(const int all_sequence_length, - const int sequence_length, - const int valid_end, - const int valid_start, - const T* attn_bias, - const T* input, - T* output, - bool causal) { - using BlockReduce = hipcub::BlockReduce; - __shared__ typename BlockReduce::TempStorage tmp_storage; - - __shared__ float sum_reverse_block; - __shared__ float max_block; - - // Input dimension is BxNxSxS*; blockIdx.y is batch index b; gridDim.x=N*S; blockIdx.x is index within N*S; - const int offset = (blockIdx.y * gridDim.x + blockIdx.x) * all_sequence_length; - const int index = offset + threadIdx.x; - - bool is_valid = false; // whether it has attention mask == 1. - - // Update end position for causal. - int end = valid_end; - if (causal) { - const int end_causal = all_sequence_length - sequence_length + (blockIdx.x % sequence_length) + 1; - if (end_causal < end) { - end = end_causal; - } - } - - is_valid = (threadIdx.x >= valid_start && threadIdx.x < end); - - // e^x is represented as infinity if x is large enough, like 100.f. - // Infinity divided by Infinity is a NAN. Thus, softmax gets a NAN if one or more item are large enough. - // a math transform as below is leveraged to get a stable softmax: - // e^xi/(e^x1 + ...e^xn) = e^(xi - max) / (e^(x1 - max) + ... + e^(xn - max)) - float input_data = attn_bias == nullptr - ? static_cast(input[index]) - : static_cast(input[index] + attn_bias[index]); - float thread_data_max = is_valid ? input_data : float(-ROCMRT_INF_F); - const auto max = BlockReduce(tmp_storage).Reduce(thread_data_max, hipcub::Max(), end); - - // Store max value - if (threadIdx.x == 0) { - max_block = max; - } - __syncthreads(); - - float thread_data_exp(0.f); - if (is_valid) { - thread_data_exp = expf(input_data - max_block); - } - - const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_exp, hipcub::Sum(), end); - - // Store value of 1.0/sum. - if (threadIdx.x == 0) { - sum_reverse_block = (1.f) / sum; - } - __syncthreads(); - - // threadIdx.x might be larger than all_sequence_length due to alignment to 32x. - if (threadIdx.x < all_sequence_length) { - output[index] = is_valid ? T(thread_data_exp * sum_reverse_block) : T(0.f); - } -} - -// Note about the attention_mask_strides and attention_mask/key_padding_mask -// attention_mask accepts 2D, 3D or 4D tensor, but it will be viewed as 3D tensor uniformally and it will be indexed -// as [batch_index, sequence_index, token_index]. 
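The masking convention described in this note (any 2D/3D/4D mask is addressed as a 3D (batch, sequence, token) view, with stride 0 marking a broadcast axis, exactly as `Get2DMaskStrides` returns {T, 0, 1} for a [B,T] mask) can be sketched host-side. The helper below is hypothetical; the rank-4 branch assumes the Megatron-style [B,1,M,M] buffer described later in this diff:

```cpp
#include <cassert>
#include <cstdint>

struct MaskStrides3 {
  int64_t batch, seq, token;  // strides for (batch, sequence, token) indexing
};

// Hypothetical helper: view a rank-2/3/4 mask uniformly as a 3-D tensor.
// A stride of 0 marks a broadcast dimension, matching Get2DMaskStrides,
// which returns {T, 0, 1} for a [B, T] mask (every query row reuses it).
MaskStrides3 ViewMaskAs3D(int rank, int64_t S, int64_t T, int64_t M) {
  if (rank == 2) return {T, 0, 1};      // [B, T]       -> broadcast over S
  if (rank == 3) return {S * T, T, 1};  // [B, S, T]    -> fully indexed
  assert(rank == 4);                    // [B, 1, M, M] viewed as a [B, 1, S, T] sub-view
  return {M * M, M, 1};
}

// Element lookup is then a dot product with the coordinates:
//   mask[b, s, t] = data[strides.batch * b + strides.seq * s + strides.token * t]
```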
-template -__global__ void SoftmaxWithRawMaskSmallKernel( - const int all_sequence_length, - const int sequence_length, - const int3 attention_mask_strides, - const int* attention_mask, // 2D, 3D or 4D attention mask - const bool* key_padding_mask, - const T* attn_bias, - const T* input, - T* output, - const bool causal, - const float rsqrt_head_size, - const bool skip_softmax, - const float mask_filter_value) { - using BlockReduce = hipcub::BlockReduce; - __shared__ typename BlockReduce::TempStorage tmp_storage; - - __shared__ float sum_reverse_block; - __shared__ float max_block; - - // Input dimension is BxNxSxS*; blockIdx.y is batch index b; gridDim.x=N*S; blockIdx.x is index within N*S; - int index = (blockIdx.y * gridDim.x + blockIdx.x) * all_sequence_length + threadIdx.x; - - // Mask all thread_data values to negative infinity to allow BlockReduce Max operation over all thread_data - // members with all invalid members set to a value that does not impact the final result. This is necessary - // to avoid the performance impact from using the valid_items interface. - float thread_data = -ROCMRT_INF_F; - if (threadIdx.x < all_sequence_length) { - thread_data = float(input[index]) * rsqrt_head_size; - - const int sequence_index = blockIdx.x % sequence_length; - if (causal) { - int from_index = all_sequence_length - sequence_length + sequence_index; // offset in all sequence length. - if (threadIdx.x > from_index) { - thread_data = -ROCMRT_INF_F; - } - } - - const int batch_index = blockIdx.y; - int mask_offset = attention_mask_strides.x * batch_index + - attention_mask_strides.y * sequence_index + - attention_mask_strides.z * threadIdx.x; - - if (nullptr == key_padding_mask) { - const int& mask = attention_mask[mask_offset]; - if (mask == 0) - thread_data += mask_filter_value; - } else { - const bool mask = key_padding_mask[mask_offset]; - if (mask) { - thread_data = -ROCMRT_INF_F; - } - } - - if (attn_bias != nullptr) { - thread_data += float(attn_bias[index]); - } - } - - if (skip_softmax) { - if (threadIdx.x < all_sequence_length) { - output[index] = T(thread_data); - } - return; - } - - const float max = BlockReduce(tmp_storage).Reduce(thread_data, hipcub::Max()); - - // Store max value - if (threadIdx.x == 0) { - max_block = max; - } - __syncthreads(); - - // Mask all thread_data_exp values to zero to allow BlockReduce Sum operation over all thread_data_exp - // members with all invalid members set to a value that does not impact the final result. This is necessary - // to avoid the performance impact from using the valid_items interface. - float thread_data_exp = threadIdx.x < all_sequence_length ? 
expf(thread_data - max_block) : 0.0f; - const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_exp, hipcub::Sum()); - - // Store value of 1.0/sum - if (threadIdx.x == 0) { - sum_reverse_block = (1.f) / sum; - } - __syncthreads(); - - if (threadIdx.x < all_sequence_length) { - output[index] = T(thread_data_exp * sum_reverse_block); - } -} - -template -__global__ void SoftmaxKernelSmall(const int all_sequence_length, const int sequence_length, - const T* attn_bias, const T* input, T* output, bool causal) { - SoftmaxSmall(all_sequence_length, sequence_length, all_sequence_length, 0, - attn_bias, input, output, causal); -} - -template -__global__ void SoftmaxKernel(const int all_sequence_length, const T* attn_bias, const T* input, T* output) { - Softmax(all_sequence_length, all_sequence_length, 0, attn_bias, input, output); -} - -template -Status ComputeSoftmax( - hipStream_t stream, - const int all_sequence_length, const int sequence_length, const int batch_size, const int num_heads, - const T* attn_bias, const T* input, T* output, bool causal) { - const dim3 grid(sequence_length * num_heads, batch_size, 1); - if (all_sequence_length <= 32) { - const int blockSize = 32; - SoftmaxKernelSmall<<>>( - all_sequence_length, sequence_length, attn_bias, input, output, causal); - } else if (all_sequence_length <= 64) { - const int blockSize = 64; - SoftmaxKernelSmall<<>>( - all_sequence_length, sequence_length, attn_bias, input, output, causal); - } else if (all_sequence_length <= 128) { - const int blockSize = 128; - SoftmaxKernelSmall<<>>( - all_sequence_length, sequence_length, attn_bias, input, output, causal); - } else if (all_sequence_length <= 256) { - const int blockSize = 256; - SoftmaxKernelSmall<<>>( - all_sequence_length, sequence_length, attn_bias, input, output, causal); - } else if (all_sequence_length <= 512) { - const int blockSize = 512; - SoftmaxKernelSmall<<>>( - all_sequence_length, sequence_length, attn_bias, input, output, causal); - } else if (all_sequence_length <= 1024) { - const int blockSize = 1024; - SoftmaxKernelSmall<<>>( - all_sequence_length, sequence_length, attn_bias, input, output, causal); - } else if (!causal) { - const int blockSize = 1024; - SoftmaxKernel<<>>( - all_sequence_length, attn_bias, input, output); - } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Attention ROCM operator does not support total sequence length > 1024."); - } - - return HIP_CALL(hipPeekAtLastError()); -} - -template -__global__ void MaskedSoftmaxKernelSmall(const int all_sequence_length, const int sequence_length, - const int* mask_end, const int* mask_start, - const T* attn_bias, const T* input, T* output, - bool causal) { - __shared__ int start_position; - __shared__ int end_position; - - if (threadIdx.x == 0) { - const int batch = blockIdx.y; - start_position = mask_start != nullptr ? max(0, mask_start[batch]) : 0; - end_position = min(all_sequence_length, mask_end[batch]); - - // Attend to no word has same effect as attend to all words. This is added to get parity with CPU result. 
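For reference, the block-size dispatch in `ComputeSoftmax` above (mirrored by the masked variants) reduces to picking the smallest power-of-two thread-block size that covers the total sequence length. A host-side sketch:

```cpp
#include <initializer_list>
#include <stdexcept>

// Host-side sketch of the block-size selection in ComputeSoftmax: choose the
// smallest supported power-of-two TPB covering the total sequence length;
// past 1024 only the non-causal looping kernel is available.
int SelectSoftmaxBlockSize(int total_sequence_length, bool causal) {
  for (int tpb : {32, 64, 128, 256, 512, 1024}) {
    if (total_sequence_length <= tpb) return tpb;  // -> SoftmaxKernelSmall<T, tpb>
  }
  if (!causal) return 1024;                        // -> SoftmaxKernel<T, 1024>
  throw std::invalid_argument("total sequence length > 1024 is unsupported");
}
```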
- if (start_position >= end_position) { - start_position = 0; - end_position = all_sequence_length; - } - } - __syncthreads(); - - SoftmaxSmall(all_sequence_length, sequence_length, end_position, start_position, - attn_bias, input, output, causal); -} - -template -__global__ void MaskedSoftmaxKernel(const int all_sequence_length, const int* mask_end, const int* mask_start, - const T* attn_bias, const T* input, T* output) { - __shared__ int start_position; - __shared__ int end_position; - - if (threadIdx.x == 0) { - const int batch = blockIdx.y; - start_position = mask_start != nullptr ? max(0, mask_start[batch]) : 0; - end_position = min(all_sequence_length, mask_end[batch]); - - // Attend to no word has same effect as attend to all words. This is added to get parity with CPU result. - if (start_position >= end_position) { - start_position = 0; - end_position = all_sequence_length; - } - } - __syncthreads(); - - Softmax(all_sequence_length, end_position, start_position, attn_bias, input, output); -} - -template -Status ComputeSoftmaxWithMask1D( - hipStream_t stream, - const int all_sequence_length, const int sequence_length, const int batch_size, const int num_heads, - const int* mask_index, const int* mask_start, - const T* attn_bias, const T* input, T* output, const bool causal) { - const dim3 grid(sequence_length * num_heads, batch_size, 1); - -#define DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(block_size) \ - MaskedSoftmaxKernelSmall<<>>( \ - all_sequence_length, sequence_length, mask_index, mask_start, \ - attn_bias, input, output, causal); - - if (all_sequence_length <= 32) { - DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(32); - } else if (all_sequence_length <= 64) { - DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(64); - } else if (all_sequence_length <= 128) { - DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(128); - } else if (all_sequence_length <= 256) { - DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(256); - } else if (all_sequence_length <= 512) { - DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(512); - } else if (all_sequence_length <= 1024) { - DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(1024); - } else if (!causal) { - const int blockSize = 1024; - MaskedSoftmaxKernel<<>>( - all_sequence_length, mask_index, mask_start, - attn_bias, input, output); - } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Attention ROCM operator does not support total sequence length > 1024."); - } - -#undef DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE - - return HIP_CALL(hipPeekAtLastError()); -} - -template -Status ComputeSoftmaxWithRawMask(Stream* ort_stream, - const int all_sequence_length, - const int sequence_length, - const int batch_size, - const int num_heads, - const int3 attention_mask_strides, - const int* attention_mask, - const bool* key_padding_mask, - const T* attn_bias, - const T* input, - T* output, - const bool causal, - const float rsqrt_head_size, - const bool use_persistent_softmax, - T* persistent_softmax_workspace, - const float mask_filter_value) { - const dim3 grid(sequence_length * num_heads, batch_size, 1); - - T* out = use_persistent_softmax ? 
persistent_softmax_workspace : output; - auto stream = static_cast(ort_stream->GetHandle()); - -#define DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(block_size) \ - SoftmaxWithRawMaskSmallKernel<<>>( \ - all_sequence_length, sequence_length, attention_mask_strides, \ - attention_mask, key_padding_mask, attn_bias, input, out, \ - causal, rsqrt_head_size, \ - use_persistent_softmax, mask_filter_value); - - if (all_sequence_length <= 32) { - DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(32); - } else if (all_sequence_length <= 64) { - DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(64); - } else if (all_sequence_length <= 128) { - DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(128); - } else if (all_sequence_length <= 256) { - DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(256); - } else if (all_sequence_length <= 512) { - DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(512); - } else if (all_sequence_length <= 1024) { - DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(1024); - } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Attention ROCM operator does not support total sequence length > 1024."); - } - -#undef DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE - - if (use_persistent_softmax) { - return dispatch_warpwise_softmax_forward(ort_stream, - output, - persistent_softmax_workspace, - all_sequence_length, - all_sequence_length, - batch_size * num_heads * sequence_length); - } - - return HIP_CALL(hipPeekAtLastError()); -} - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_permute_pipelines.cuh b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_permute_pipelines.cuh deleted file mode 100644 index 213940f132963..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_permute_pipelines.cuh +++ /dev/null @@ -1,125 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
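To summarize the deleted softmax header numerically: every kernel in it uses the max-subtraction trick from the comments, with invalid positions contributing -inf to the max and 0 to the sum. A plain CPU reference of that computation, useful for parity tests (a hypothetical helper, not part of the original sources):

```cpp
#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

// CPU reference of the numerically stable masked softmax the kernels compute:
// e^(x_i - max) / sum_j e^(x_j - max), where invalid positions contribute
// -inf to the max and 0 to the sum, so they come out exactly 0.
std::vector<float> MaskedSoftmaxReference(const std::vector<float>& logits,
                                          const std::vector<bool>& valid) {
  float mx = -std::numeric_limits<float>::infinity();
  for (size_t i = 0; i < logits.size(); ++i)
    if (valid[i]) mx = std::max(mx, logits[i]);

  std::vector<float> out(logits.size(), 0.0f);
  float sum = 0.0f;
  for (size_t i = 0; i < logits.size(); ++i)
    if (valid[i]) sum += out[i] = std::exp(logits[i] - mx);

  // The kernels reset an empty [start, end) range to the full sequence before
  // this point, so sum is never 0 there; a CPU caller must guarantee the same.
  for (float& v : out) v /= sum;
  return out;
}
```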
- -#pragma once - -#include "core/providers/rocm/rocm_common.h" -#include "core/providers/rocm/rocm_kernel.h" -#include "core/providers/rocm/tunable/gemm.h" -#include "core/providers/rocm/tunable/rocm_tunable.h" -#include "contrib_ops/cpu/bert/attention_common.h" -#include "contrib_ops/cpu/bert/attention_parameters.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -namespace blas = onnxruntime::rocm::tunable::blas; - -namespace { -std::tuple GetQkvProjectGemmMNKBatch(const AttentionParameters* attention) { - int m = attention->sequence_length; - int n = (attention->hidden_size + attention->hidden_size + attention->v_hidden_size); // q + k + v - int k = attention->input_hidden_size; - int batch = attention->batch_size; - return {m, n, k, batch}; -} -} // namespace - -template -struct GemmPermuteParams : onnxruntime::rocm::tunable::OpParams { - std::string Signature() const override { - auto [m, n, k, batch] = GetQkvProjectGemmMNKBatch(attention); - return MakeString("M", m, "_N", n, "_K", k, "_B", batch); - } - - hipblasHandle_t handle; - const AttentionParameters* attention; - const hipDeviceProp_t* device_prop; - - const T* input_buffer; - const T* weight_buffer; - const T* bias_buffer; - T* out_buffer; - - int3 bias_strides; - - const T* ones; // used for broadcasting bias if the underlying algorithm does not support strides - T* workspace_buffer; -}; - -template -struct GemmPermuteGenericPipeline { - inline static size_t GetOutputNumBytes(const AttentionParameters* attn) { - auto [m, n, _, batch] = GetQkvProjectGemmMNKBatch(attn); - return sizeof(T) * m * n * batch; - } - - inline static size_t GetWorkspaceNumBytes(const AttentionParameters* attn) { - return GetOutputNumBytes(attn); - } - - inline static std::tuple GetGemmMNK(const GemmPermuteParams* params) { - auto [m, n, k, batch] = GetQkvProjectGemmMNKBatch(params->attention); - return {batch * m, n, k}; - } - - inline static std::tuple UnspliceOutputQKV(const GemmPermuteParams* params) { - auto* attn = params->attention; - int64_t batch = attn->batch_size * attn->num_heads; - int64_t num_elems_per_batch = attn->sequence_length * attn->head_size; - int64_t num_elems = batch * num_elems_per_batch; - auto q = params->out_buffer + 0 * num_elems; - auto k = params->out_buffer + 1 * num_elems; - auto v = params->out_buffer + 2 * num_elems; - return {q, k, v}; - } - - inline static Status BroadcastBias(const GemmPermuteParams* params) { - auto [m, n, k] = GetGemmMNK(params); - // Bias shape is (N), broadcast using B(M, N) = ones(M, 1) x bias(1, N). - // TODO: use custom kernel of expand to improve the performance. - return blas::row_major::Gemm( - params->TuningContext(), params->Stream(), params->handle, - blas::BlasOp::NonTrans, blas::BlasOp::NonTrans, - m, n, 1, - /*alpha=*/1.0f, - params->ones, 1, - params->bias_buffer, n, - /*beta=*/0.0f, - params->workspace_buffer, n); - } - - inline static Status Gemm(const GemmPermuteParams* params) { - auto [m, n, k] = GetGemmMNK(params); - // result(M, N) = input x weights + bias. 
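`BroadcastBias` above materializes the (N)-shaped bias into the full (M, N) output as a rank-1 GEMM, ones(M,1) x bias(1,N), so the GEMM that follows can accumulate into it with beta = 1. A CPU sketch of the same two-step trick (illustrative, not the ORT API):

```cpp
#include <vector>

// CPU sketch of GemmPermuteGenericPipeline's bias handling:
// step 1: out = ones(M,1) * bias(1,N)    (every row becomes a copy of bias)
// step 2: out += input(M,K) * weight(K,N) (the real GEMM runs with beta = 1)
void GemmWithBroadcastBias(int M, int N, int K,
                           const std::vector<float>& input,   // M*K, row-major
                           const std::vector<float>& weight,  // K*N, row-major
                           const std::vector<float>& bias,    // N
                           std::vector<float>& out) {         // M*N, row-major
  for (int m = 0; m < M; ++m)
    for (int n = 0; n < N; ++n) out[m * N + n] = bias[n];     // rank-1 broadcast
  for (int m = 0; m < M; ++m)
    for (int k = 0; k < K; ++k)
      for (int n = 0; n < N; ++n)
        out[m * N + n] += input[m * K + k] * weight[k * N + n];
}
```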
- return blas::row_major::Gemm( - params->TuningContext(), params->Stream(), params->handle, - blas::BlasOp::NonTrans, blas::BlasOp::NonTrans, - m, n, k, - /*alpha=*/1.0f, - params->input_buffer, k, - params->weight_buffer, n, - /*beta=*/1.0f, - params->workspace_buffer, n); - } - - inline static Status Permute0213(const GemmPermuteParams* params) { - auto* attn = params->attention; - // input should be BxSx3xNxH => gemm_buffer: 3xBxNxSxH - return LaunchTransQkv( - params->StreamHandle(), 3, attn->sequence_length, attn->batch_size, attn->head_size, attn->num_heads, - params->device_prop->maxThreadsPerBlock, false, params->workspace_buffer, params->out_buffer); - } - - static Status Run(const GemmPermuteParams* params) { - ORT_RETURN_IF_ERROR(BroadcastBias(params)); - ORT_RETURN_IF_ERROR(Gemm(params)); - ORT_RETURN_IF_ERROR(Permute0213(params)); - return Status::OK(); - } -}; - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl.cuh b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl.cuh deleted file mode 100644 index be8508670e4b1..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl.cuh +++ /dev/null @@ -1,177 +0,0 @@ -// SPDX-License-Identifier: MIT -// Modifications Copyright (c) Microsoft Corporation. All rights reserved. -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#ifdef USE_COMPOSABLE_KERNEL -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp" -#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp" -#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" -#include "ck/utility/data_type.hpp" - -namespace onnxruntime { -namespace contrib { -namespace rocm { -namespace internal { - -using F16 = ck::half_t; -using F32 = float; - -template -using S = ck::Sequence; - -using MaskingSpecialization = ck::tensor_operation::device::MaskingSpecialization; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -using ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute; // the interface -using ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle; // the implementation - -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; -static constexpr auto GemmPadded = ck::tensor_operation::device::GemmSpecialization::MNKOPadding; - -static constexpr auto TensorDefault = ck::tensor_operation::device::TensorSpecialization::Default; - -template -using device_batched_gemm_softmax_gemm_permute_instances = - std::tuple< - // clang-format off - // #############################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| AData| B0Data| B1Data| CData| Acc0BiasData| Acc1BiasData| AccData| CShuffle| A| B0| Acc0| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| 
B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| MaskingSpec| D0s Bias| - // #############################################| | | | | | Type| Type| Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | SrcScalar| - // #############################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | PerVector| - // #############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec, 1>, - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, -#if ROCM_VERSION >= 50500 - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 256, 32, 64, 32, 8, 8, 2, 32, 32, 1, 8, 2, S<4, 64, 1>, S<1, 0, 
2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, -#endif - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 256, 32, 128, 32, 8, 8, 2, 32, 32, 1, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 32, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8, MaskingSpec>, - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, 
PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 32, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8, MaskingSpec>, - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 64, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8, MaskingSpec>, - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 64, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8, MaskingSpec>, - // Padded fallback kernel - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec, 1>, - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec> - // clang-format on - >; - -struct PreSoftmaxAttentionScoreOp { - PreSoftmaxAttentionScoreOp(float scale) : scale_(scale) {} - - // non-biased, non-masked - __host__ __device__ void operator()(float& y, const float& x) const { - y = scale_ * x; - } - - // biased or converted masked - __host__ __device__ void operator()(float& y, const float& x, const F16& bias) const { - y = scale_ * x + ck::type_convert(bias); - } - - // biased and converted masked - __host__ __device__ void operator()(float& y, const float& x, const F16& bias, const 
F16& converted_mask) const { - y = scale_ * x + ck::type_convert(bias) + ck::type_convert(converted_mask); - } - - float scale_; -}; - -// Use this function to gat implementation -template -std::vector, - PassThrough, PassThrough, D0Op, PassThrough, PassThrough, - MaskingSpec>>> -GetDeviceBatchedGemmSoftmaxGemmPermuteInstances() { - return {}; -} - -// implemented in impl_{fp16,bf16}[_biased][_masked].cu -// fp16, non-biased, non-masked -template <> -std::vector, ck::Tuple<>, - PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, - MaskingSpecialization::MaskDisabled>>> -GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< - F16, ck::Tuple<>, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskDisabled>(); - -// fp16, biased, non-masked -template <> -std::vector, ck::Tuple<>, - PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, - MaskingSpecialization::MaskDisabled>>> -GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< - F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskDisabled>(); - -// fp16, biased, fp16 masked, basically, two bias -template <> -std::vector, ck::Tuple<>, - PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, - MaskingSpecialization::MaskDisabled>>> -GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< - F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskDisabled>(); - -template <> -std::vector, ck::Tuple<>, - PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, - MaskingSpecialization::MaskOutUpperTriangle>>> -GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< - F16, ck::Tuple<>, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>(); - -// fp16, biased, non-masked -template <> -std::vector, ck::Tuple<>, - PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, - MaskingSpecialization::MaskOutUpperTriangle>>> -GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< - F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>(); - -// fp16, biased, fp16 masked, basically, two bias -template <> -std::vector, ck::Tuple<>, - PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, - MaskingSpecialization::MaskOutUpperTriangle>>> -GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< - F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>(); - -} // namespace internal -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime -#endif // USE_COMPOSABLE_KERNEL diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16.cu b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16.cu deleted file mode 100644 index 2e32a6594d164..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16.cu +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
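`PreSoftmaxAttentionScoreOp` above is the Acc0 element-wise op that CK applies to each raw Q*K' accumulator: scale it, then optionally add a bias and/or a mask that was pre-converted into an additive term. With the CK types stripped (`ck::half_t` replaced by `float`), the overload set amounts to:

```cpp
// Plain-C++ rendering of PreSoftmaxAttentionScoreOp's three overloads.
struct PreSoftmaxScore {
  explicit PreSoftmaxScore(float scale) : scale_(scale) {}

  // non-biased, non-masked
  void operator()(float& y, float x) const { y = scale_ * x; }

  // biased (or a converted mask reused as a bias)
  void operator()(float& y, float x, float bias) const { y = scale_ * x + bias; }

  // biased and converted-masked: both are simply additive terms
  void operator()(float& y, float x, float bias, float converted_mask) const {
    y = scale_ * x + bias + converted_mask;
  }

  float scale_;
};
```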
- -#ifdef USE_COMPOSABLE_KERNEL -#include "contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl.cuh" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp" - -namespace onnxruntime { -namespace contrib { -namespace rocm { -namespace internal { - -using NonBiasedNonmasked = DeviceBatchedGemmSoftmaxGemmPermute< - 2, 1, 1, 1, 1, - F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, - PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, - MaskingSpecialization::MaskDisabled>; - -template <> -std::vector> -GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< - F16, ck::Tuple<>, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskDisabled>() { - std::vector> instances; - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, - device_batched_gemm_softmax_gemm_permute_instances< - 2, 1, 1, 1, 1, - F16, ck::Tuple<>, F32, PreSoftmaxAttentionScoreOp, - MaskingSpecialization::MaskDisabled>{}); - - return instances; -} - -using NonBiasedNonmaskedCausal = DeviceBatchedGemmSoftmaxGemmPermute< - 2, 1, 1, 1, 1, - F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, - PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, - MaskingSpecialization::MaskOutUpperTriangle>; - -template <> -std::vector> -GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< - F16, ck::Tuple<>, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>() { - std::vector> instances; - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, - device_batched_gemm_softmax_gemm_permute_instances< - 2, 1, 1, 1, 1, - F16, ck::Tuple<>, F32, PreSoftmaxAttentionScoreOp, - MaskingSpecialization::MaskOutUpperTriangle>{}); - - return instances; -} - -} // namespace internal -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime -#endif // USE_COMPOSABLE_KERNEL diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased.cu b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased.cu deleted file mode 100644 index 91da8d9e1f9a8..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased.cu +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
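Each `impl_fp16*.cu` file in this directory defines a single explicit specialization of `GetDeviceBatchedGemmSoftmaxGemmPermuteInstances`, so the expensive composable-kernel instantiations compile in parallel translation units. A toy rendering of the pattern (all names below are illustrative):

```cpp
#include <memory>
#include <vector>

// Shared header: declare a templated factory with no definition.
struct OpBase {
  virtual ~OpBase() = default;
};

struct Half {};  // stand-in for ck::half_t

template <typename DataT, bool Masked>
std::vector<std::unique_ptr<OpBase>> GetInstances();

// One source file per specialization (e.g. impl_fp16_masked.cpp) keeps the
// slow instantiations in separate TUs that the build can compile in parallel.
struct Fp16MaskedOp final : OpBase {};

template <>
std::vector<std::unique_ptr<OpBase>> GetInstances<Half, true>() {
  std::vector<std::unique_ptr<OpBase>> instances;
  instances.push_back(std::make_unique<Fp16MaskedOp>());
  return instances;
}
```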
- -#ifdef USE_COMPOSABLE_KERNEL -#include "contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl.cuh" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp" - -namespace onnxruntime { -namespace contrib { -namespace rocm { -namespace internal { - -using BiasedNonmasked = DeviceBatchedGemmSoftmaxGemmPermute< - 2, 1, 1, 1, 1, - F16, F16, F16, F16, ck::Tuple, ck::Tuple<>, - PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, - MaskingSpecialization::MaskDisabled>; - -template <> -std::vector> -GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< - F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskDisabled>() { - std::vector> instances; - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, - device_batched_gemm_softmax_gemm_permute_instances< - 2, 1, 1, 1, 1, - F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, - MaskingSpecialization::MaskDisabled>{}); - - return instances; -} - -using BiasedNonmaskedCausal = DeviceBatchedGemmSoftmaxGemmPermute< - 2, 1, 1, 1, 1, - F16, F16, F16, F16, ck::Tuple, ck::Tuple<>, - PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, - MaskingSpecialization::MaskOutUpperTriangle>; - -template <> -std::vector> -GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< - F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>() { - std::vector> instances; - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, - device_batched_gemm_softmax_gemm_permute_instances< - 2, 1, 1, 1, 1, - F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, - MaskingSpecialization::MaskOutUpperTriangle>{}); - - return instances; -} - -} // namespace internal -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime -#endif // USE_COMPOSABLE_KERNEL diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased_biased.cu b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased_biased.cu deleted file mode 100644 index b08123be18977..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased_biased.cu +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
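The "biased_biased" file that follows registers the "two bias" instances: an integer or boolean mask is first converted into an additive tensor so that CK can push mask and bias through the same broadcast-add path. A sketch of that conversion, assuming the conventional large negative mask_filter_value (-10000 is the default in the Attention contrib ops):

```cpp
#include <vector>

// Convert a boolean keep-mask into an additive mask: 0 where attention is
// allowed, mask_filter_value (a large negative number) where it is not.
// After this, mask and bias feed through the same broadcast-add element op.
std::vector<float> ConvertMaskToAdditive(const std::vector<bool>& keep,
                                         float mask_filter_value = -10000.0f) {
  std::vector<float> additive(keep.size());
  for (size_t i = 0; i < keep.size(); ++i)
    additive[i] = keep[i] ? 0.0f : mask_filter_value;
  return additive;
}
```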
- -#ifdef USE_COMPOSABLE_KERNEL -#include "contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl.cuh" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp" - -namespace onnxruntime { -namespace contrib { -namespace rocm { -namespace internal { - -using BiasedNonmasked = DeviceBatchedGemmSoftmaxGemmPermute< - 2, 1, 1, 1, 1, - F16, F16, F16, F16, ck::Tuple, ck::Tuple<>, - PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, - MaskingSpecialization::MaskDisabled>; - -template <> -std::vector> -GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< - F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskDisabled>() { - std::vector> instances; - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, - device_batched_gemm_softmax_gemm_permute_instances< - 2, 1, 1, 1, 1, - F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, - MaskingSpecialization::MaskDisabled>{}); - - return instances; -} - -using BiasedNonmaskedCausal = DeviceBatchedGemmSoftmaxGemmPermute< - 2, 1, 1, 1, 1, - F16, F16, F16, F16, ck::Tuple, ck::Tuple<>, - PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, - MaskingSpecialization::MaskOutUpperTriangle>; - -template <> -std::vector> -GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< - F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>() { - std::vector> instances; - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, - device_batched_gemm_softmax_gemm_permute_instances< - 2, 1, 1, 1, 1, - F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, - MaskingSpecialization::MaskOutUpperTriangle>{}); - - return instances; -} - -} // namespace internal -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime -#endif // USE_COMPOSABLE_KERNEL diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh deleted file mode 100644 index 226b89cfb2b86..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh +++ /dev/null @@ -1,915 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -/* About Computing in these Pipelines - -B: batch size of Attention Op. NOTE: To be disambiguated with batch size of GEMMs -S: sequence length -T: total sequence length -N: num of heads -H: head dimension - -The following use qkv_format == Q_K_V_BNSH (mode == BSNH_BNLH_BNLH_NONE_NONE_NONE_NONE) as a example: - -BN: B*N, which is the batch size of GEMMs. NOTE: To be disambiguated with batch size of Attention Op - -In QKV projection (prior to this pipeline): - /-> Q [B,S,N*H] ->Reshape-> [B,S,N,H] ->Permute0213-> [B,N,S,H] -X --o--> K [B,T,N*H] ->Reshape-> [B,T,N,H] ->Permute0213-> [B,N,T,H] - \-> V [B,T,N*H] ->Reshape-> [B,T,N,H] ->Permute0213-> [B,N,T,H] - -pre_softmax_attn_scores = Q*K' = [B,N,S,H] * [BxNxTxH]' = [B,N,S,T] Batched GEMM1 -pre_softmax_attn_scores_masked = pre_softmax_attn_scores * scale +? bias +? mask Scale Add Bias, +? 
is optional
-attn_scores = softmax(pre_softmax_attn_scores_masked) = [B,N,S,T]                          Softmax
-scaled_multi_head_attn = attn_scores * V = [B,N,S,T] * [B,N,T,H] = [B,N,S,H]               Batched GEMM2
-
-Op outputs scaled_multi_head_attn:
-[B,N,S,H] ->Permute0213-> [B,S,N,H] ->Reshape-> [B,S,N*H]
-
-
-For the computation of pre_softmax_attn_scores +? mask +? bias:
-
-GemmSoftmaxGemmPermuteGenericPipeline handles it in specialized softmax. TODO: remove it!
-
-CK in GemmSoftmaxGemmPermuteTunablePipeline
-
-   Q*K' ---> scale ---> [B,N,S,T] -------+?--> masked
-   bias --------------> [B,N,S,T] --+?--/
-mask_2d ---> [B,T] ---> [B,1,1,T] -/
-
-   Q*K' ---> scale ---> [B,N,S,T] -------+?--> masked
-   bias --------------> [B,N,S,T] --+?--/
-mask_3d --> [B,S,T] --> [B,1,S,T] -/
-
-   Q*K' ---> scale ---> [B,N,S,T] -------+?--> masked
-   bias --------------> [B,N,S,T] --+?--/
-mask_4d -> [B,1,M,M] -> [B,1,S,T] -/    M is max_sequence_length from megatron; we will create a
-                                        **sub-view** from the original mask buffer
-
-For the CK implementation, there are four cases combined:
-non-biased, non-masked, no special processing.
-    biased, non-masked, no special processing, add the bias directly.
-non-biased,     masked, convert the mask to [B,1,1_or_S,T] and perform broadcast add with scaled Q*K'.
-    biased,     masked, convert the mask to [B,1,1_or_S,T] and perform broadcast add with bias and scaled Q*K'.
-
-Broadcast add does not actually perform the broadcasting; it just broadcasts the load operation from memory. The impl
-details are in composable kernel. The scale and add logic is performed via Acc0ElementOp.
-
-# Classified modes:
-
-| Q      | K      | V    | past(K)| pastV | present(K)| presentV | Op, desc
-| ----   | ----   | ---- | ------ | ----- | --------- | -------- | ---------
-| QFMT   | KFMT   | VFMT | -      | -     | -         | -        | A, basic, qkv is impl dependent by qkv_format
-| QFMT   | KFMT   | VFMT | 2BNPH  | -     | 2BNTH *^  | -        | A, past_present_share_buffer = false, qkv is impl dependent by qkv_format
-| QFMT   | KFMT   | VFMT | 2BNMH  | -     | 2BNMH *^  | -        | A, past_present_share_buffer = true, qkv is impl dependent by qkv_format
-| BSNH   | BLNH*  | BLNH^| -      | -     | -         | -        | MHA basic
-| BSNH   | BNLH*  | BNLH^| -      | -     | -         | -        | MHA cross, pass_past_in_kv = true
-| BSNH   | -      | -    | -      | -     | BNLH *    | BNLH ^   | MHA cross, pass_past_in_kv = false
-| BSNH   | BLNH   | BLNH | -      | -     | BNTH *    | BNTH ^   | MHA cross, past_present_share_buffer = false
-| BSNH   | BNLH   | BNLH | -      | -     | BNTH *    | BNTH ^   | MHA cross, past_present_share_buffer = false
-| BSNH   | BLNH   | BLNH | -      | -     | BNMH *    | BNMH ^   | MHA cross, past_present_share_buffer = true
-| BSNH   | BNLH   | BNLH | -      | -     | BNMH *    | BNMH ^   | MHA cross, past_present_share_buffer = true
-| BSNH   | BLNH   | BLNH | BNPH   | BNPH  | BNTH *    | BNTH ^   | MHA self, past_present_share_buffer = false
-| BSNH   | BNLH   | BNLH | BNPH   | BNPH  | BNTH *    | BNTH ^   | MHA self, past_present_share_buffer = false
-| BSNH   | BLNH   | BLNH | BNMH   | BNMH  | BNMH *    | BNMH ^   | MHA self, past_present_share_buffer = true
-| BSNH   | BNLH   | BNLH | BNMH   | BNMH  | BNMH *    | BNMH ^   | MHA self, past_present_share_buffer = true
-| BLN3H*^| -      | -    | -      | -     | -         | -        | MHA basic, qkv_packed
-| BSNH   | BLN2H*^| -    | -      | -     | -         | -        | MHA basic, kv_packed
-
-Q, K, V, past(K), pastV, present(K), presentV are the inputs of the contrib OpKernel
-
-About k_buffer and v_buffer, we always explicitly concat past to present and use present_k for k_buffer and present_v
-for v_buffer
-
-- Marked with `*` indicates the Tensor is used for k_buffer passing.
-- Marked with `^` indicates the Tensor is used for v_buffer passing.
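
Editor's illustrative note (a sketch, not from the original comment): the [B,1,1,T] and
[B,1,S,T] views above are realized purely through strides, where a stride of 0 repeats the
same data along that axis. For a 2D key-padding mask of shape [B,T]:

  // strides {T, 0, 1} for coordinate (b, s, t), matching what Get2DMaskStrides below returns
  int64_t MaskOffset(int b, int s, int t, int T) {
    return static_cast<int64_t>(b) * T + s * 0 + t;  // s contributes nothing: every s reads the same [T] row
  }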
-
-# Supported Op
-
-- A: Attention
-- MHA: MultiHeadAttention
-
-# Dim Value
-
-- B: batch_size
-- N: num_heads
-- H: head_size
-
-- S: sequence_length
-- L: kv_sequence_length
-- P: past_sequence_length
-- T: total_sequence_length = P + L
-- M: max_sequence_length
-*/
-
-#include "core/framework/tensor_shape.h"
-#include "core/providers/rocm/tunable/gemm.h"
-#include "core/providers/rocm/tunable/rocm_tunable.h"
-#include "contrib_ops/cpu/bert/attention_base.h"
-#include "contrib_ops/rocm/bert/attention_impl.h"
-#include "contrib_ops/rocm/bert/attention_softmax.h"
-#ifdef USE_COMPOSABLE_KERNEL
-#include "contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl.cuh"
-#include "core/providers/rocm/composable_kernel_common.h"
-
-#include "ck/ck.hpp"
-#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
-#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
-#endif // USE_COMPOSABLE_KERNEL
-
-#include
-#include
-
-namespace blas = onnxruntime::rocm::tunable::blas;
-
-namespace onnxruntime {
-namespace contrib {
-namespace rocm {
-
-inline int3 Get2DMaskStrides(int total_sequence_length) {
-  // stride == 0 indicates broadcasting
-  return {total_sequence_length, 0, 1};
-}
-
-// A stride maps a natural coordinate to a physical offset into the underlying memory storage buffer. We need to
-// specify both the natural coordinate order, say (b,n,s,h), (b,s,n,h) or (b,n,h,s), and the memory order, say BNSH or
-// BSNH, to determine the strides. To obtain the offset, we just take the inner product of the coordinate with the
-// strides. This wrapper creates the stride vector from the physical dimensions (the physical shape).
-struct Strides {
-  // Create the strides for a BNSH physically indexed memory buffer
-  static Strides BNSHMemory(int batch_dim,
-                            int num_head_dim,
-                            int seqlen_dim,
-                            int head_size_dim) {
-    ORT_UNUSED_PARAMETER(batch_dim);
-    return Strides{LongLong4{
-        static_cast<int64_t>(num_head_dim) * seqlen_dim * head_size_dim,
-        static_cast<int64_t>(seqlen_dim) * head_size_dim,
-        static_cast<int64_t>(head_size_dim),
-        static_cast<int64_t>(1),
-    }};
-  }
-
-  // Create the strides for a BSNH physically indexed memory buffer
-  static Strides BSNHMemory(int batch_dim,
-                            int seqlen_dim,
-                            int num_head_dim,
-                            int head_size_dim) {
-    ORT_UNUSED_PARAMETER(batch_dim);
-    return Strides{LongLong4{
-        static_cast<int64_t>(seqlen_dim) * num_head_dim * head_size_dim,
-        static_cast<int64_t>(head_size_dim),
-        static_cast<int64_t>(num_head_dim) * head_size_dim,
-        static_cast<int64_t>(1),
-    }};
-  }
-
-  template <typename T>
-  T ForBNSHCoord() const {
-    using E = typename T::value_type;
-    return T{static_cast<E>(strides_for_bnsh_coord.x),
-             static_cast<E>(strides_for_bnsh_coord.y),
-             static_cast<E>(strides_for_bnsh_coord.z),
-             static_cast<E>(strides_for_bnsh_coord.w)};
-  }
-
-  template <typename T>
-  T ForBSNHCoord() const {
-    using E = typename T::value_type;
-    return T{static_cast<E>(strides_for_bnsh_coord.x),
-             static_cast<E>(strides_for_bnsh_coord.z),
-             static_cast<E>(strides_for_bnsh_coord.y),
-             static_cast<E>(strides_for_bnsh_coord.w)};
-  }
-
-  template <typename T>
-  T ForBNHSCoord() const {
-    using E = typename T::value_type;
-    return T{static_cast<E>(strides_for_bnsh_coord.x),
-             static_cast<E>(strides_for_bnsh_coord.y),
-             static_cast<E>(strides_for_bnsh_coord.w),
-             static_cast<E>(strides_for_bnsh_coord.z)};
-  }
-
-  int64_t OffsetAt(int b, int n, int s, int h) const {
-    return strides_for_bnsh_coord.x * b + strides_for_bnsh_coord.y * n +
-           strides_for_bnsh_coord.z * s + strides_for_bnsh_coord.w * h;
-  }
-
-  // store intermediate strides in the canonical (b,n,s,h) coordinate order
-  LongLong4 strides_for_bnsh_coord;
-};
-
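// --- Editor's illustrative sketch (not part of the original file) ---
// The same (b,n,s,h) coordinate maps to different physical offsets depending on the
// memory order; OffsetAt is just the inner product with the stride vector above.
// The dims below are hypothetical.
inline void StridesExample() {
  const int B = 2, N = 12, S = 128, H = 64;
  (void)B;
  Strides bnsh = Strides::BNSHMemory(B, N, S, H);  // strides {N*S*H, S*H, H, 1}
  Strides bsnh = Strides::BSNHMemory(B, S, N, H);  // strides {S*N*H, H, N*H, 1}
  int64_t off_bnsh = bnsh.OffsetAt(1, 3, 5, 7);    // ((1*N + 3)*S + 5)*H + 7
  int64_t off_bsnh = bsnh.OffsetAt(1, 3, 5, 7);    // ((1*S + 5)*N + 3)*H + 7
  (void)off_bnsh;
  (void)off_bsnh;
}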
-template -std::tuple ConvertToOffsetedBufferViews( - const RocmAttentionParameters* attn, - const T* query = nullptr, // q or packed_qkv - const T* key = nullptr, // k or packed kv - const T* value = nullptr, // - const T* present = nullptr, // present or present_k - const T* present_v = nullptr) { - switch (attn->mode) { - case QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE: - case BSNH_BNLH_BNLH_NONE_NONE_NONE_NONE: - case BSNH_BLNH_BLNH_NONE_NONE_NONE_NONE: { - return {reinterpret_cast(query), - reinterpret_cast(key), - reinterpret_cast(value)}; - } - case QFMT_KFMT_VFMT_NONE_NONE_2BNTH_NONE: - case QFMT_KFMT_VFMT_2BNPH_NONE_2BNTH_NONE: { - auto offset = static_cast(attn->batch_size) * attn->num_heads * attn->total_sequence_length * - attn->head_size; - return {reinterpret_cast(query), - reinterpret_cast(present), - reinterpret_cast(present) + offset}; - } - case QFMT_KFMT_VFMT_NONE_NONE_2BNMH_NONE: - case QFMT_KFMT_VFMT_2BNMH_NONE_2BNMH_NONE: { - auto offset = static_cast(attn->batch_size) * attn->num_heads * attn->max_sequence_length * - attn->head_size; - return {reinterpret_cast(query), - reinterpret_cast(present), - reinterpret_cast(present) + offset}; - } - case BSNH_BLNH_BLNH_NONE_NONE_BNTH_BNTH: - case BSNH_BNLH_BNLH_NONE_NONE_BNTH_BNTH: - case BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH: - case BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH: - case BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH: - case BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH: - case BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH: - case BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH: - return {reinterpret_cast(query), - reinterpret_cast(present), - reinterpret_cast(present_v)}; - case BSNH_BLN2H_NONE_NONE_NONE_NONE_NONE: { - auto packed_kv = reinterpret_cast(key); - return {reinterpret_cast(query), packed_kv, packed_kv + attn->head_size}; - } - case BLN3H_NONE_NONE_NONE_NONE_NONE_NONE: { - auto packed_qkv = reinterpret_cast(query); - return {packed_qkv, packed_qkv + 1 * attn->head_size, packed_qkv + 2 * attn->head_size}; - } - default: - ORT_ENFORCE("unreachable"); - return {}; - } -} - -inline std::tuple GetQkvStrides(const RocmAttentionParameters* attn) { - // G0 not used, because it is the slowest dimension - const int& B = attn->batch_size; - const int& N = attn->num_heads; - const int& S = attn->sequence_length; - const int& L = attn->kv_sequence_length; - const int& T = attn->total_sequence_length; - const int& M = attn->max_sequence_length; - const int& H = attn->head_size; - const int& Hv = attn->v_head_size; - - switch (attn->mode) { - case QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE: - if (attn->qkv_format == Q_K_V_BNSH) { - return { - Strides::BNSHMemory(B, N, S, H), - Strides::BNSHMemory(B, N, L, H), - Strides::BNSHMemory(B, N, L, Hv), - }; - } else if (attn->qkv_format == Q_K_V_BSNH) { - return { - Strides::BSNHMemory(B, S, N, H), - Strides::BSNHMemory(B, L, N, H), - Strides::BSNHMemory(B, L, N, Hv), - }; - } - case BSNH_BLNH_BLNH_NONE_NONE_BNTH_BNTH: - case BSNH_BNLH_BNLH_NONE_NONE_BNTH_BNTH: - case BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH: - case BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH: - return { - Strides::BSNHMemory(B, S, N, H), - Strides::BNSHMemory(B, N, T, H), - Strides::BNSHMemory(B, N, T, Hv), - }; - case BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH: - case BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH: - case BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH: - case BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH: - return { - Strides::BSNHMemory(B, S, N, H), - Strides::BNSHMemory(B, N, M, H), - Strides::BNSHMemory(B, N, M, Hv), - }; - case BSNH_BLNH_BLNH_NONE_NONE_NONE_NONE: - return { - Strides::BSNHMemory(B, S, N, H), - 
Strides::BSNHMemory(B, L, N, H), - Strides::BSNHMemory(B, L, N, Hv), - }; - case BSNH_BNLH_BNLH_NONE_NONE_NONE_NONE: - return { - Strides::BSNHMemory(B, S, N, H), - Strides::BNSHMemory(B, N, L, H), - Strides::BNSHMemory(B, N, L, Hv), - }; - case BSNH_BLN2H_NONE_NONE_NONE_NONE_NONE: - return { - Strides::BSNHMemory(B, S, N, H), - Strides::BSNHMemory(B, L, N, 2 * H), - Strides::BSNHMemory(B, L, N, 2 * Hv), - }; - case BLN3H_NONE_NONE_NONE_NONE_NONE_NONE: - return { - Strides::BSNHMemory(B, L, N, 3 * H), - Strides::BSNHMemory(B, L, N, 3 * H), - Strides::BSNHMemory(B, L, N, 3 * Hv), - }; - default: - ORT_ENFORCE("unreachable"); - return {}; - } -} - -inline std::tuple GetRawMaskBufferAddrSizesAndStrides( - const int* buffer, const RocmAttentionParameters* attn) { - const int* offseted_buffer{buffer}; // how to view the mask buffer - int3 sizes{0, 0, 0}; // the logical shape of the view - int3 strides{-1, -1, -1}; // the physical memory layout - switch (attn->mask_type) { - case MASK_NONE: - case MASK_2D_DUMMY: - break; // No mask - case MASK_2D_KEY_PADDING: - sizes = {attn->batch_size, 1, attn->total_sequence_length}; - strides = Get2DMaskStrides(attn->total_sequence_length); - break; - case MASK_3D_ATTENTION: - sizes = {attn->batch_size, attn->sequence_length, attn->total_sequence_length}; - strides = {attn->sequence_length * attn->total_sequence_length, attn->total_sequence_length, 1}; - break; - case MASK_4D_MEGATRON: - // offset to skip past sequence part, so that we can index it with [batch_index, sequence_index, token_index] - offseted_buffer = buffer + attn->past_sequence_length * attn->max_sequence_length; - sizes = {attn->batch_size, attn->sequence_length, attn->total_sequence_length}; - strides = {attn->max_sequence_length * attn->max_sequence_length, attn->max_sequence_length, 1}; - break; - default: - LOGS_DEFAULT(FATAL) << "unsupported mask type: " << attn->mask_type; - throw std::runtime_error("unsupported mask type"); - } - return {offseted_buffer, sizes, strides}; -} - -template -struct GemmSoftmaxGemmPermuteParams : onnxruntime::rocm::tunable::OpParams { - std::string Signature() const override { - return MakeString( - "B", attention->batch_size, - "_S", attention->sequence_length, - "_T", attention->total_sequence_length, - "_N", attention->num_heads, - "_H", attention->head_size, - "_Hv", attention->v_head_size, - bias_buffer != nullptr ? "_B" : "_NB", - "_M", mask_index_dims.size(), - "_QKV", attention->qkv_format, - "_MODE", attention->mode); - } - - std::tuple GetGemmsMNKOBatch() const { - ORT_ENFORCE(attention != nullptr); - auto m = attention->sequence_length; - auto n = attention->total_sequence_length; - auto k = attention->head_size; - auto o = attention->v_head_size; - auto batch = attention->batch_size * attention->num_heads; - return {m, n, k, o, batch}; - } - - hipblasHandle_t handle; - const RocmAttentionParameters* attention; - const hipDeviceProp_t* device_prop; - - float scale; - const T* q_buffer; - const T* k_buffer; - const T* v_buffer; - T* out_buffer; - - // optional, attention bias [B,N,S,T] - // TODO: support shape [B,1,S,T], [1, N, S, T], [1, 1, S, T] with broadcast. 
- const T* bias_buffer{nullptr}; - - // optional, mask value - const int* mask_index_buffer{nullptr}; - TensorShapeVector mask_index_dims{}; - - // optional, internal - void* workspace_buffer{nullptr}; -}; - -inline bool IsKVBNMH(AttentionMode mode) { - switch (mode) { - case QFMT_KFMT_VFMT_NONE_NONE_2BNMH_NONE: - case QFMT_KFMT_VFMT_2BNMH_NONE_2BNMH_NONE: - case BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH: - case BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH: - case BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH: - case BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH: - return true; - default: - return false; - } -} - -template -struct GemmSoftmaxGemmPermuteGenericPipeline { - static bool UseRawAttentionMask(const GemmSoftmaxGemmPermuteParams* params) { - return params->mask_index_buffer != nullptr && params->mask_index_dims.size() >= 2; - } - - static std::tuple GetWorkspacePlan(const GemmSoftmaxGemmPermuteParams* params) { - auto bytes = GetAttentionScratchSize( - sizeof(T), - params->attention->batch_size, - params->attention->num_heads, - params->attention->sequence_length, - params->attention->total_sequence_length); - auto gemm1_out = reinterpret_cast(params->workspace_buffer); - auto softmax_out = gemm1_out + (bytes / sizeof(T)); - auto gemm2_out = softmax_out + (bytes / sizeof(T)); - return {gemm1_out, softmax_out, gemm2_out}; - } - - inline static size_t GetWorkspaceNumBytes(const RocmAttentionParameters* attn) { - return GetAttentionWorkspaceSize( - sizeof(T), - attn->batch_size, - attn->num_heads, - attn->head_size, - attn->sequence_length, - attn->total_sequence_length); - } - - inline static Status Gemm1(const GemmSoftmaxGemmPermuteParams* params) { - auto [m, n, k, o, batch] = params->GetGemmsMNKOBatch(); - auto [gemm1_out, softmax_out, gemm2_out] = GetWorkspacePlan(params); - - int k_buffer_stride = n * k; - if (IsKVBNMH(params->attention->mode)) { - k_buffer_stride = params->attention->max_sequence_length * params->attention->head_size; - } - - // GEMM1 [m,k] * [n,k]' -> [m,n] - return blas::row_major::StridedBatchedGemm( - params->TuningContext(), params->Stream(), params->handle, - blas::BlasOp::NonTrans, blas::BlasOp::Trans, - m, n, k, - // For raw attention mask, the scalar is moved to softmax computation. - /*alpha=*/UseRawAttentionMask(params) ? 1.0f : params->scale, - params->q_buffer, k, m * k, - params->k_buffer, k, k_buffer_stride, - /*beta=*/0.0f, - gemm1_out, n, m * n, - batch); - } - - inline static Status SoftmaxRawMask(const GemmSoftmaxGemmPermuteParams* params, bool use_persistent_softmax) { - // Softmax on [m,n] along the n dimension. - // Raw attention mask could be 2D (B,S) or 3D (B,S,T) or 4D(B,1,M,M), where M is the max sequence length. - auto attn = params->attention; - auto [buffer, sizes, strides] = GetRawMaskBufferAddrSizesAndStrides(params->mask_index_buffer, attn); - auto [gemm1_out, softmax_out, gemm2_out] = GetWorkspacePlan(params); - T* persistent_softmax_workspace = gemm1_out; // replace Q*K' in place if persistent softmax is selected. - return ComputeSoftmaxWithRawMask( - params->Stream(), attn->total_sequence_length, attn->sequence_length, attn->batch_size, attn->num_heads, - strides, buffer, nullptr, params->bias_buffer, gemm1_out, softmax_out, - attn->is_unidirectional, /* FIXME: this must not be attn.scale! 
*/ params->scale,
-        use_persistent_softmax, persistent_softmax_workspace, attn->mask_filter_value);
-  }
-
-  inline static Status Softmax1DIndexMask(const GemmSoftmaxGemmPermuteParams<T>* params) {
-    auto mask_1d = params->mask_index_buffer;
-    auto mask_1d_size = params->mask_index_dims[0];
-    // Softmax on [m,n] along the n dimension.
-    // mask_index has 1D shape: either (batch_size) or (2*batch_size). Only the latter one has start positions.
-    auto attn = params->attention;
-    auto [gemm1_out, softmax_out, gemm2_out] = GetWorkspacePlan(params);
-    const int* mask_start = (mask_1d_size > attn->batch_size) ? mask_1d + attn->batch_size : nullptr;
-    return ComputeSoftmaxWithMask1D(
-        params->StreamHandle(), attn->total_sequence_length, attn->sequence_length, attn->batch_size, attn->num_heads,
-        mask_1d, mask_start, params->bias_buffer, gemm1_out, softmax_out, attn->is_unidirectional);
-  }
-
-  inline static Status SoftmaxNoMask(const GemmSoftmaxGemmPermuteParams<T>* params) {
-    // Softmax on [m,n] along the n dimension.
-    auto attn = params->attention;
-    auto [gemm1_out, softmax_out, gemm2_out] = GetWorkspacePlan(params);
-    return ComputeSoftmax(
-        params->StreamHandle(), attn->total_sequence_length, attn->sequence_length, attn->batch_size, attn->num_heads,
-        params->bias_buffer, gemm1_out, softmax_out, attn->is_unidirectional);
-  }
-
-  inline static Status Gemm2(const GemmSoftmaxGemmPermuteParams<T>* params) {
-    auto [m, n, k, o, batch] = params->GetGemmsMNKOBatch();
-    auto [gemm1_out, softmax_out, gemm2_out] = GetWorkspacePlan(params);
-
-    int v_buffer_stride = n * o;
-    if (IsKVBNMH(params->attention->mode)) {
-      v_buffer_stride = params->attention->max_sequence_length * params->attention->v_head_size;
-    }
-
-    // GEMM2 [m,n] * [n,o] -> [m,o]
-    // semantically, the output buffer contains B*N matrices of shape [S,H], stored compactly, thus B,N,S,H.
- return blas::row_major::StridedBatchedGemm( - params->TuningContext(), params->Stream(), params->handle, - blas::BlasOp::NonTrans, blas::BlasOp::NonTrans, - m, o, n, - /*alpha=*/1.0f, - softmax_out, n, m * n, - params->v_buffer, o, v_buffer_stride, - /*beta=*/0.0f, - gemm2_out, o, m * o, - batch); - } - - inline static Status Permute0213(const GemmSoftmaxGemmPermuteParams* params) { - // Permute 0213 - // gemm2_out is B,N,S,H, transpose to out_buffer as B,S,N,H - auto attn = params->attention; - auto [gemm1_out, softmax_out, gemm2_out] = GetWorkspacePlan(params); - return LaunchTransCtx( - params->StreamHandle(), - attn->sequence_length, attn->batch_size, attn->head_size, attn->num_heads, - params->device_prop->maxThreadsPerBlock, false, gemm2_out, params->out_buffer); - } - - static Status GetSupportedStatus(const GemmSoftmaxGemmPermuteParams* params) { - const auto& attn = params->attention; - // TODO: address the BNMH k,v strides - switch (attn->mode) { - case QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE: - case QFMT_KFMT_VFMT_NONE_NONE_2BNTH_NONE: - case QFMT_KFMT_VFMT_2BNPH_NONE_2BNTH_NONE: - case QFMT_KFMT_VFMT_NONE_NONE_2BNMH_NONE: - case QFMT_KFMT_VFMT_2BNMH_NONE_2BNMH_NONE: - if (attn->qkv_format == Q_K_V_BNSH) { - return Status::OK(); - } else { - return TUNABLE_OP_UNSUPPORTED("GenericPipeline only supports qkv_format as Q_K_V_BNSH, got ", - attn->qkv_format); - } - case BSNH_BLNH_BLNH_NONE_NONE_NONE_NONE: - return TUNABLE_OP_UNSUPPORTED("GenericPipeline only supports qkv_format as Q_K_V_BNSH but k, v are BLNH"); - case BSNH_BNLH_BNLH_NONE_NONE_NONE_NONE: - case BSNH_BLNH_BLNH_NONE_NONE_BNTH_BNTH: - case BSNH_BNLH_BNLH_NONE_NONE_BNTH_BNTH: - case BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH: - case BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH: - case BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH: - case BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH: - case BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH: - case BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH: - // If sequence_length is 1, query of B1NH can be simply viewed as BN1H. 
-        if (attn->sequence_length == 1) {
-          return Status::OK();
-        } else {
-          return TUNABLE_OP_UNSUPPORTED("GenericPipeline only supports qkv_format as Q_K_V_BNSH; ",
-                                        "only when sequence_length is 1 can a query of BSNH be viewed as BNSH");
-        }
-      case BLN3H_NONE_NONE_NONE_NONE_NONE_NONE:
-      case BSNH_BLN2H_NONE_NONE_NONE_NONE_NONE:
-        return TUNABLE_OP_UNSUPPORTED("GenericPipeline only supports qkv_format as Q_K_V_BNSH");
-      default:
-        return TUNABLE_OP_UNSUPPORTED("unknown mode");
-    }
-    return TUNABLE_OP_UNSUPPORTED("unknown case");
-  }
-
-  static Status Run(const GemmSoftmaxGemmPermuteParams<T>* params, bool use_persistent_softmax) {
-    auto supported_status = GetSupportedStatus(params);
-    if (!supported_status.IsOK()) {
-      return supported_status;
-    }
-    ORT_RETURN_IF_ERROR(Gemm1(params));
-
-    if (UseRawAttentionMask(params)) {
-      ORT_RETURN_IF_ERROR(SoftmaxRawMask(params, use_persistent_softmax));
-    } else if (params->mask_index_dims.size() == 1) {  // 1d index mask
-      ORT_RETURN_IF_ERROR(Softmax1DIndexMask(params));
-    } else {
-      ORT_RETURN_IF_ERROR(SoftmaxNoMask(params));
-    }
-
-    ORT_RETURN_IF_ERROR(Gemm2(params));
-    ORT_RETURN_IF_ERROR(Permute0213(params));
-    return Status::OK();
-  }
-};
-
-template <typename T>
-class GemmSoftmaxGemmPermuteTunableOp : public tunable::TunableOp<GemmSoftmaxGemmPermuteParams<T>> {
- public:
-  GemmSoftmaxGemmPermuteTunableOp();
-
-  inline static bool IsSupportedMode(const RocmAttentionParameters* attn) {
-    switch (attn->mode) {
-      case QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE:
-      case QFMT_KFMT_VFMT_2BNPH_NONE_2BNTH_NONE:
-        // depends on qkv format
-        if (attn->qkv_format == Q_K_V_BNSH || attn->qkv_format == Q_K_V_BSNH) {
-          return true;
-        } else {
-          return false;
-        }
-      case BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH:
-      case BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH:
-      case BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH:
-      case BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH:
-      case BSNH_BLNH_BLNH_NONE_NONE_NONE_NONE:
-      case BSNH_BNLH_BNLH_NONE_NONE_NONE_NONE:
-      case BSNH_BLNH_BLNH_NONE_NONE_BNTH_BNTH:
-      case BSNH_BNLH_BNLH_NONE_NONE_BNTH_BNTH:
-      case BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH:
-      case BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH:
-      case BSNH_BLN2H_NONE_NONE_NONE_NONE_NONE:
-      case BLN3H_NONE_NONE_NONE_NONE_NONE_NONE:
-        return true;
-      default:
-        return false;
-    }
-  }
-
-  inline static bool IsSupportedMaskType(const RocmAttentionParameters* attn) {
-    switch (attn->mask_type) {
-      case MASK_NONE:
-      case MASK_2D_DUMMY:
-      case MASK_2D_KEY_PADDING:
-      case MASK_3D_ATTENTION:
-      case MASK_4D_MEGATRON:
-        return true;
-      default:
-        return false;
-    }
-  }
-
-  inline static size_t GetWorkspaceNumBytes(const RocmAttentionParameters* attn) {
-    size_t num_bytes = GemmSoftmaxGemmPermuteGenericPipeline<T>::GetWorkspaceNumBytes(attn);
-
-#ifdef USE_COMPOSABLE_KERNEL
-    if (IsSupportedMaskType(attn)) {
-      auto [buffer, sizes, strides] = GetRawMaskBufferAddrSizesAndStrides(nullptr, attn);
-      num_bytes = std::max(num_bytes, sizeof(T) * sizes.x * sizes.y * sizes.z);
-    }
-#endif
-
-    return num_bytes;
-  }
-
-  template <int VecSize, typename Converter>
-  __global__ static void ConvertToFilledMaskValue(
-      T* __restrict__ out,
-      const int3 out_strides,
-      const int* __restrict__ mask_buffer,
-      const int3 mask_lengths,  // [B,S,T]
-      const int3 mask_strides,
-      Converter cvt) {
-    const int64_t global_idx = blockDim.x * blockIdx.x + threadIdx.x;
-    if (global_idx >= mask_lengths.x * mask_lengths.y * CeilDiv(mask_lengths.z, VecSize)) {
-      return;
-    }
-
-    const int tidx = (global_idx % CeilDiv(mask_lengths.z, VecSize)) * VecSize;
-    const int bs_idx = global_idx / CeilDiv(mask_lengths.z, VecSize);
-    const int sidx = bs_idx % mask_lengths.y;
-    const int
bidx = bs_idx / mask_lengths.y; - - int64_t in_offset = mask_strides.x * bidx + mask_strides.y * sidx + mask_strides.z * tidx; - int64_t out_offset = out_strides.x * bidx + out_strides.y * sidx + out_strides.z * tidx; - - if (tidx + VecSize <= mask_lengths.z) { - using LoadT = const aligned_vector; - using StoreT = aligned_vector; - LoadT load = *reinterpret_cast(mask_buffer + in_offset); - StoreT store; - -#pragma unroll - for (int i = 0; i < VecSize; i++) { - store.val[i] = cvt(load.val[i]); - } - *reinterpret_cast(out + out_offset) = store; - } else { -#pragma unroll - for (int i = 0; i < mask_lengths.z - tidx; i++) { - out[out_offset + i] = cvt(mask_buffer[in_offset + i]); - } - } - } - - static Status LaunchConvertToFilledMaskValue(const GemmSoftmaxGemmPermuteParams* params) { - constexpr const int kThreadPerBlock = 256; - constexpr const int kVecSize = 4; - - auto attn = params->attention; - auto [buffer, lengths, strides] = GetRawMaskBufferAddrSizesAndStrides(params->mask_index_buffer, attn); - int64_t total_threads = lengths.x * lengths.y * CeilDiv(lengths.z, kVecSize); - auto num_blocks = CeilDiv(total_threads, kThreadPerBlock); - - auto mask_filter_value = attn->mask_filter_value; - auto cvt = [=] __device__(int v) -> T { - return v == 1 ? 0 : mask_filter_value; - }; - - ConvertToFilledMaskValue<<StreamHandle()>>>( - reinterpret_cast(params->workspace_buffer), {lengths.y * lengths.z, lengths.z, 1}, // out desc - buffer, lengths, strides, // mask desc - cvt); - - return HIP_CALL(hipGetLastError()); - } -}; - -#ifdef USE_COMPOSABLE_KERNEL - -template -auto GetArgAndRunInvoker(const U& impl, const V& invoker, const GemmSoftmaxGemmPermuteParams* params) { - constexpr const int kNumBiasBuffer = static_cast(USE_BIAS) + static_cast(USE_MASK); - - using Nop = ck::tensor_operation::element_wise::PassThrough; - using Acc0ElementOp = internal::PreSoftmaxAttentionScoreOp; - - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - !GemmSoftmaxGemmPermuteTunableOp::IsSupportedMode(params->attention), - "attention mode is not supported, got ", params->attention->mode); - if constexpr (USE_BIAS) { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->bias_buffer == nullptr, "biased version only support input with bias"); - } else { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->bias_buffer != nullptr, "non-biased version only support input without bias"); - } - if constexpr (USE_MASK) { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - !GemmSoftmaxGemmPermuteTunableOp::IsSupportedMaskType(params->attention), - "mask type is not supported, got ", params->attention->mask_type); - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->mask_index_buffer == nullptr, "masked version only support input with mask"); - } else { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->mask_index_buffer != nullptr, "non-masked version only support input without mask"); - } - - auto attn = params->attention; - const int& G0 = attn->batch_size; - const int& G1 = attn->num_heads; - const int& M = attn->sequence_length; - const int& N = attn->total_sequence_length; - const int& K = attn->head_size; - const int& O = attn->v_head_size; - { - auto [m, n, k, o, batch] = params->GetGemmsMNKOBatch(); - ORT_ENFORCE(M == m && N == n && K == k && O == o && G0 * G1 == batch, "semantic mismatch"); - } - - auto [qs, ks, vs] = GetQkvStrides(attn); - std::vector q_buffer_lengths = {G0, G1, M, K}; - std::vector q_buffer_strides = qs.template ForBNSHCoord>(); - std::vector k_buffer_lengths = {G0, G1, N, K}; - std::vector 
k_buffer_strides = ks.template ForBNSHCoord>(); - std::vector v_buffer_lengths = {G0, G1, O, N}; - std::vector v_buffer_strides = vs.template ForBNHSCoord>(); - std::vector out_buffer_lengths = {G0, G1, M, O}; - std::vector out_buffer_strides = {M * G1 * O, O, G1 * O, 1}; // permute 0213 - - std::array bias_buffers{}; - std::array, kNumBiasBuffer> bias_lengths{}; - std::array, kNumBiasBuffer> bias_strides{}; - if constexpr (USE_BIAS) { - bias_buffers[0] = const_cast(params->bias_buffer); - bias_lengths[0] = {G0, G1, M, N}; // BN(G0*G1), S(M), T(N) - bias_strides[0] = {G1 * M * N, M * N, N, 1}; - } - if constexpr (USE_MASK) { - bias_buffers[kNumBiasBuffer - 1] = params->workspace_buffer; - bias_lengths[kNumBiasBuffer - 1] = {G0, G1, M, N}; // BN(G0*G1), S(M), T(N) - if (params->mask_index_dims.size() == 2) { // [B,T] - bias_strides[kNumBiasBuffer - 1] = {N, 0, 0, 1}; - } else if (params->mask_index_dims.size() == 3) { // [B,S,T] - bias_strides[kNumBiasBuffer - 1] = {M * N, 0, N, 1}; - } else if (params->mask_index_dims.size() == 4) { // [B,1,max_seq_len,max_seq_len] -->convert--> [B,S,T] - bias_strides[kNumBiasBuffer - 1] = {M * N, 0, N, 1}; - } else { - ORT_ENFORCE(false, "Unreachable"); - } - } - - auto arg = impl->MakeArgumentPointer( - params->q_buffer, params->k_buffer, params->v_buffer, params->out_buffer, - bias_buffers, // Gemm1 bias, as attention mask - {}, // Gemm2 bias - q_buffer_lengths, q_buffer_strides, - k_buffer_lengths, k_buffer_strides, - v_buffer_lengths, v_buffer_strides, - out_buffer_lengths, out_buffer_strides, - bias_lengths, bias_strides, - {}, - {}, - Nop{}, - Nop{}, - Acc0ElementOp{params->scale}, - Nop{}, - Nop{}); - - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), - impl->GetTypeString(), " does not support the params"); - - if constexpr (USE_MASK) { - ORT_RETURN_IF_ERROR(GemmSoftmaxGemmPermuteTunableOp::LaunchConvertToFilledMaskValue(params)); - } - - invoker->Run(arg.get(), StreamConfig{params->StreamHandle()}); - return Status::OK(); -} - -template -auto GetCKGemmSoftmaxGemmPermuteTypeStringAndOps() { - using CKDataType = typename CKDataTypeAdaptor::type; - using D0DataType = typename ck::detail::tuple_concat< - std::conditional_t, ck::Tuple<>>, - std::conditional_t, ck::Tuple<>>>::type; - - constexpr static auto MaskingSpecMaskDisabled = - ck::tensor_operation::device::MaskingSpecialization::MaskDisabled; - constexpr static auto MaskingSpecMaskOutUpperTriangle = - ck::tensor_operation::device::MaskingSpecialization::MaskOutUpperTriangle; - - std::vector>>> - ret; - - for (auto&& impl : internal::GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< - CKDataType, D0DataType, internal::F32, internal::PreSoftmaxAttentionScoreOp, MaskingSpecMaskDisabled>()) { - auto type_string = impl->GetTypeString(); - - auto invoker = impl->MakeInvokerPointer(); - auto op = [impl = std::move(impl), invoker = std::move(invoker)]( - const GemmSoftmaxGemmPermuteParams* params) -> Status { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->attention->is_unidirectional, "unidirectional attention is not supported with MaskingSpecMaskDisabled"); - - return GetArgAndRunInvoker(impl, invoker, params); - }; - ret.emplace_back(std::make_pair(std::move(type_string), std::move(op))); - } - - for (auto&& impl : internal::GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< - CKDataType, D0DataType, internal::F32, internal::PreSoftmaxAttentionScoreOp, MaskingSpecMaskOutUpperTriangle>()) { - auto type_string = impl->GetTypeString(); - - auto invoker = 
impl->MakeInvokerPointer();
-    auto op = [impl = std::move(impl), invoker = std::move(invoker)](
-                  const GemmSoftmaxGemmPermuteParams<T>* params) -> Status {
-      TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
-          !params->attention->is_unidirectional, "bidirectional attention is not supported with MaskingSpecMaskOutUpperTriangle");
-      TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
-          params->attention->sequence_length != params->attention->total_sequence_length,
-          "sequence_length != total_sequence_length is not supported with MaskingSpecMaskOutUpperTriangle");
-
-      return GetArgAndRunInvoker<T, USE_BIAS, USE_MASK>(impl, invoker, params);
-    };
-    ret.emplace_back(std::make_pair(std::move(type_string), std::move(op)));
-  }
-
-  return ret;
-}
-#endif // USE_COMPOSABLE_KERNEL
-
-template <typename T>
-GemmSoftmaxGemmPermuteTunableOp<T>::GemmSoftmaxGemmPermuteTunableOp() {
-  this->RegisterOp([](const GemmSoftmaxGemmPermuteParams<T>* params) {
-    return GemmSoftmaxGemmPermuteGenericPipeline<T>::Run(params, false);
-  });
-
-#ifdef USE_COMPOSABLE_KERNEL
-  for (auto&& [_, op] : GetCKGemmSoftmaxGemmPermuteTypeStringAndOps<T, /*USE_BIAS=*/false, /*USE_MASK=*/false>()) {
-    this->RegisterOp(std::move(op));
-  }
-
-  for (auto&& [_, op] : GetCKGemmSoftmaxGemmPermuteTypeStringAndOps<T, /*USE_BIAS=*/true, /*USE_MASK=*/false>()) {
-    this->RegisterOp(std::move(op));
-  }
-
-  for (auto&& [_, op] : GetCKGemmSoftmaxGemmPermuteTypeStringAndOps<T, /*USE_BIAS=*/false, /*USE_MASK=*/true>()) {
-    this->RegisterOp(std::move(op));
-  }
-
-  for (auto&& [_, op] : GetCKGemmSoftmaxGemmPermuteTypeStringAndOps<T, /*USE_BIAS=*/true, /*USE_MASK=*/true>()) {
-    this->RegisterOp(std::move(op));
-  }
-#endif
-}
-
-} // namespace rocm
-} // namespace contrib
-} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/rocm/bert/decoder_attention_impl.h b/onnxruntime/contrib_ops/rocm/bert/decoder_attention_impl.h
deleted file mode 100644
index 0aff519d20e99..0000000000000
--- a/onnxruntime/contrib_ops/rocm/bert/decoder_attention_impl.h
+++ /dev/null
@@ -1,46 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
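// --- Editor's illustrative sketch (not part of the original sources) ---
// The tunable-op pattern registered in the pipelines header above, reduced to plain
// C++: keep a list of interchangeable candidates, skip unsupported ones, and cache a
// choice per problem signature. All names here are hypothetical; the real TunableOp
// additionally benchmarks every supported candidate and keeps the fastest.
#include <functional>
#include <string>
#include <unordered_map>
#include <vector>

struct MiniParams {
  std::string Signature() const { return "B2_S128_T128"; }  // made-up cache key
};

struct MiniTunableOp {
  using Candidate = std::function<bool(const MiniParams&)>;  // true if supported and run
  std::vector<Candidate> ops;
  std::unordered_map<std::string, size_t> best;  // signature -> chosen candidate

  bool Run(const MiniParams& p) {
    auto it = best.find(p.Signature());
    if (it != best.end()) return ops[it->second](p);
    for (size_t i = 0; i < ops.size(); ++i) {
      if (ops[i](p)) {  // real impl: time all supported candidates, pick the fastest
        best[p.Signature()] = i;
        return true;
      }
    }
    return false;
  }
};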
- -#pragma once - -#include -#include -#include "contrib_ops/cpu/bert/attention_common.h" -#include "core/providers/rocm/shared_inc/rocm_utils.h" -#include "core/providers/rocm/tunable/rocm_tunable.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -Status LaunchDecoderAttentionKernel( - const hipDeviceProp_t& prop, // Device Properties - RocmTuningContext* tuning_ctx, // context for tuning - Stream* stream, // ORT Stream - hipblasHandle_t& hipblas, // hipblas handle - const size_t element_size, // Element size of input tensor - const int batch_size, // Batch size (B) - const int sequence_length, // Sequence length (S) - const int kv_sequence_length, // Key/Value/Cache sequence length - const int num_heads, // Number of attention heads (N) - const int head_size, // Hidden layer size per head (H) - const bool static_kv, // Whether cross attention or not - const bool use_past, // Whether use cache or not - const bool has_layer_state, // Whether output cache or not - const bool has_key_padding_mask, // Whether use key_padding_mask or not - const float mask_filter_value, // Mask filter value - const void* gemm_query_buffer, // Query buffer - const void* gemm_kv_buffer, // Key and value buffer - const bool* key_padding_mask, // Key padding mask - const void* key_cache, // Input key cache - const void* value_cache, // Input value cache - void* qkv_buffer, // Temporary buffer - void* workspace_buffer, // Temporary buffer - void* output, // Output tensor - void* new_key_cache, // New_key_cache tensor - void* new_value_cache // New_value_cache tensor -); - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/elementwise.h b/onnxruntime/contrib_ops/rocm/bert/elementwise.h deleted file mode 100644 index 768295767835a..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/elementwise.h +++ /dev/null @@ -1,84 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "core/common/common.h" -#include "core/providers/rocm/tunable/rocm_tunable.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -template -Status LaunchElementwiseKernel(RocmTuningContext* tuning_ctx, Stream* stream, - const T* input, int input_length, - const T* bias, int bias_length, - T* output); - -// The following is LaunchElementwiseKernel implementation detail. Their interfaces are exposed for kernel explorer. 
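// --- Editor's illustrative sketch (not part of the original sources) ---
// The FastGeLU functor in elementwise_impl/impl.cuh below avoids tanh by using the
// identity 0.5 * (1 + tanh(u)) == 1 / (1 + exp(-2u)), with
// u = sqrt(2/pi) * (x + 0.044715 * x^3). A host-side check of the two forms:
#include <cassert>
#include <cmath>

inline float FastGeluTanhForm(float x) {
  const float u = 0.7978845608028654f * (x + 0.044715f * x * x * x);
  return 0.5f * x * (1.0f + std::tanh(u));
}

inline float FastGeluSigmoidForm(float x) {  // the form the device functor computes
  const float u = 0.7978845608028654f * (x + 0.044715f * x * x * x);
  return x / (1.0f + std::exp(-2.0f * u));
}

inline void CheckFastGeluForms() {
  for (float x : {-3.0f, -0.5f, 0.0f, 0.5f, 3.0f}) {
    assert(std::fabs(FastGeluTanhForm(x) - FastGeluSigmoidForm(x)) < 1e-5f);
  }
}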
-namespace internal { - -template -struct ElementwiseParams : OpParams { - ElementwiseParams(RocmTuningContext* tuning_ctx, onnxruntime::Stream* stream, - const T* input, const T* bias, T* output, int input_length, int bias_length) - : OpParams(tuning_ctx, stream), - input(input), - bias(bias), - output(output), - input_length(input_length), - bias_length(bias_length) {} - - std::string Signature() const override { - std::string sig = std::to_string(input_length) + "_" + std::to_string(bias_length); - return sig; - } - - const T* input; - const T* bias; - T* output; - int input_length; - int bias_length; -}; - -template -class ElementwiseOp { - public: - Status operator()(const ElementwiseParams* params); - Status IsSupported(const ElementwiseParams* params); -}; - -template -Status ElementwiseStaticSelection(const ElementwiseParams* params); - -template -class ElementwiseTunableOp : public TunableOp> { - public: - ElementwiseTunableOp(); -}; - -} // namespace internal - -#define ELEMENTWISE_FWD_DECL(FnName, T) \ - namespace functor { \ - struct FnName; \ - } - -ELEMENTWISE_FWD_DECL(FastGeLU, float); -ELEMENTWISE_FWD_DECL(FastGeLU, double); -ELEMENTWISE_FWD_DECL(FastGeLU, half); -ELEMENTWISE_FWD_DECL(FastGeLU, BFloat16); - -ELEMENTWISE_FWD_DECL(GeLU, float); -ELEMENTWISE_FWD_DECL(GeLU, double); -ELEMENTWISE_FWD_DECL(GeLU, half); -ELEMENTWISE_FWD_DECL(GeLU, BFloat16); - -ELEMENTWISE_FWD_DECL(ReLU, float); -ELEMENTWISE_FWD_DECL(ReLU, half); -ELEMENTWISE_FWD_DECL(ReLU, BFloat16); - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl.cuh b/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl.cuh deleted file mode 100644 index 8255e70d27e48..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl.cuh +++ /dev/null @@ -1,256 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "core/providers/rocm/tunable/util.h" -#include "core/providers/rocm/cu_inc/common.cuh" -#include "contrib_ops/rocm/bert/elementwise.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -namespace functor { - -struct FastGeLU { - template - __host__ __device__ __forceinline__ void operator()(T& y, const T& x) const { - constexpr const float b = 0.7978845608028654f; // sqrt(2.0/M_PI) - - // const T cdf = a + a * _Tanh(in * (c * in * in + b)); - const T xb = x * T(b); - const T u = xb * T(0.044715f) * x * x + xb; - const T emu = __expf(-u - u); - const T cdf = T(1.0f) / (T(1.0f) + emu); - y = x * cdf; - } -}; - -struct GeLU { - template - __host__ __device__ __forceinline__ void operator()(T& y, const T& x) const { - y = T(0.5f) * x * (T(1.f) + T(erf(0.70710678118f * float(x)))); - } -}; - -struct ReLU { - template - __host__ __device__ __forceinline__ void operator()(T& y, const T& x) const { - y = x >= T{} ? x : T{}; - } -}; - -} // namespace functor - -using onnxruntime::rocm::CeilDiv; -using onnxruntime::rocm::GPU_WARP_SIZE; - -template -__global__ void ElementwiseKernel( - const T* __restrict__ input, int input_length, - const T* __restrict__ bias, int bias_length, - T* __restrict__ output) { - const int idx = blockIdx.x * TPB + threadIdx.x; - Fn f{}; - - if (idx < input_length) { - const T x = input[idx] + (bias == nullptr ? 
T{} : bias[idx % bias_length]); - f(output[idx], x); - } -} - -template -__global__ void ElementwiseKernelVec( - const T* __restrict__ input, int input_length, - const T* __restrict__ bias, int bias_length, - T* output) { - using VecT = onnxruntime::rocm::aligned_vector; - Fn f{}; - - const int idx = (blockIdx.x * TPB + threadIdx.x) * ILP; - if (idx < input_length) { - T input_v[ILP]; - VecT* input_val = reinterpret_cast(&input_v); - *input_val = *reinterpret_cast(&input[idx]); - T output_v[ILP]; - VecT* output_val = reinterpret_cast(&output_v); - T bias_v[ILP]; - if (bias != nullptr) { - VecT* bias_val = reinterpret_cast(&bias_v); - *bias_val = *reinterpret_cast(&bias[idx % bias_length]); - } - -#pragma unroll - for (int i = 0; i < ILP; i++) { - const T x = (bias == nullptr) ? input_v[i] : (T)(input_v[i] + bias_v[i]); - f(output_v[i], x); - } - *(reinterpret_cast(&output[idx])) = *output_val; - } -} - -template -Status LaunchElementwiseKernel( - RocmTuningContext* tuning_ctx, Stream* stream, - const T* input, int input_length, - const T* bias, int bias_length, - T* output) { - internal::ElementwiseParams params(tuning_ctx, stream, input, bias, output, input_length, bias_length); - if (tuning_ctx->IsTunableOpEnabled()) { - static internal::ElementwiseTunableOp op; - return op(¶ms); - } - - return internal::ElementwiseStaticSelection(¶ms); -} - -namespace internal { - -template -Status ElementwiseOp::operator()(const ElementwiseParams* params) { - dim3 blocks(CeilDiv(params->input_length, ThreadsPerBlock * VecSize)); - ElementwiseKernelVec<<StreamHandle()>>>( - params->input, params->input_length, - params->bias, params->bias_length, - params->output); - return HIP_CALL(hipGetLastError()); -} - -template -Status ElementwiseOp::IsSupported(const ElementwiseParams* params) { - // TODO(anyone): Add tail handling for FastGelu - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - !((params->bias_length > 0 && params->bias_length % VecSize == 0 && params->input_length % VecSize == 0) || - (params->bias_length == 0 && params->input_length % VecSize == 0))); - // Avoid redundant configurations - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!(params->input_length > (ThreadsPerBlock - GPU_WARP_SIZE) * VecSize)); - - return Status::OK(); -} - -template -Status ElementwiseStaticSelection(const ElementwiseParams* params) { - constexpr int block_size = 256; - if constexpr (std::is_same_v) { - if (params->bias != nullptr) { - if (0 == (params->bias_length % 8) && (params->input_length >= 3145728)) { // 3145728=8*128*3072 - const int grid_size = (params->input_length / 8 + block_size - 1) / block_size; - ElementwiseKernelVec<<StreamHandle()>>>( - params->input, params->input_length, params->bias, params->bias_length, params->output); - } else if (0 == (params->bias_length % 4)) { - const int grid_size = (params->input_length / 4 + block_size - 1) / block_size; - ElementwiseKernelVec<<StreamHandle()>>>( - params->input, params->input_length, params->bias, params->bias_length, params->output); - } else if (0 == (params->bias_length % 2)) { - const int grid_size = (params->input_length / 2 + block_size - 1) / block_size; - ElementwiseKernelVec<<StreamHandle()>>>( - params->input, params->input_length, params->bias, params->bias_length, params->output); - } else { - const int grid_size = (params->input_length + block_size - 1) / block_size; - ElementwiseKernel<<StreamHandle()>>>( - params->input, params->input_length, params->bias, params->bias_length, params->output); - } - } else { - if (0 == (params->input_length % 
8) && (params->input_length >= 3145728)) { // 3145728=8*128*3072 - const int grid_size = (params->input_length / 8 + block_size - 1) / block_size; - ElementwiseKernelVec<<StreamHandle()>>>( - params->input, params->input_length, params->bias, params->bias_length, params->output); - } else if (0 == (params->input_length % 4)) { - const int grid_size = (params->input_length / 4 + block_size - 1) / block_size; - ElementwiseKernelVec<<StreamHandle()>>>( - params->input, params->input_length, params->bias, params->bias_length, params->output); - } else if (0 == (params->input_length % 2)) { - const int grid_size = (params->input_length / 2 + block_size - 1) / block_size; - ElementwiseKernelVec<<StreamHandle()>>>( - params->input, params->input_length, params->bias, params->bias_length, params->output); - } else { - const int grid_size = (params->input_length + block_size - 1) / block_size; - ElementwiseKernel<<StreamHandle()>>>( - params->input, params->input_length, params->bias, params->bias_length, params->output); - } - } - } else { - const int grid_size = (params->input_length + block_size - 1) / block_size; - ElementwiseKernel<<StreamHandle()>>>( - params->input, params->input_length, params->bias, params->bias_length, params->output); - } - return HIP_CALL(hipGetLastError()); -} - -template -ElementwiseTunableOp::ElementwiseTunableOp() { - this->RegisterOp(ElementwiseStaticSelection); - - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); -} - -#undef ADD_OP - -} // namespace internal - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime - -#define ELEMENTWISE_KERNEL_IMPL(Fn, T) \ - namespace onnxruntime { \ - namespace contrib { \ - namespace rocm { \ - template Status LaunchElementwiseKernel( \ - RocmTuningContext * tuning_ctx, Stream* stream, \ - const T* input, int input_length, \ - const T* bias, int bias_length, \ - T* output); \ - namespace internal { \ - template class ElementwiseTunableOp; \ - } \ - } \ - } \ - } diff --git a/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_fastgelu.cu 
b/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_fastgelu.cu deleted file mode 100644 index c2a670ea76aca..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_fastgelu.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "contrib_ops/rocm/bert/elementwise_impl/impl.cuh" - -ELEMENTWISE_KERNEL_IMPL(functor::FastGeLU, float); -ELEMENTWISE_KERNEL_IMPL(functor::FastGeLU, double); -ELEMENTWISE_KERNEL_IMPL(functor::FastGeLU, half); -ELEMENTWISE_KERNEL_IMPL(functor::FastGeLU, BFloat16); diff --git a/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_gelu.cu b/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_gelu.cu deleted file mode 100644 index 97f0f74640c6e..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_gelu.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "contrib_ops/rocm/bert/elementwise_impl/impl.cuh" - -ELEMENTWISE_KERNEL_IMPL(functor::GeLU, double); -ELEMENTWISE_KERNEL_IMPL(functor::GeLU, float); -ELEMENTWISE_KERNEL_IMPL(functor::GeLU, half); -ELEMENTWISE_KERNEL_IMPL(functor::GeLU, BFloat16); diff --git a/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_relu.cu b/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_relu.cu deleted file mode 100644 index 67e50869133f5..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_relu.cu +++ /dev/null @@ -1,8 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "contrib_ops/rocm/bert/elementwise_impl/impl.cuh" - -ELEMENTWISE_KERNEL_IMPL(functor::ReLU, float); -ELEMENTWISE_KERNEL_IMPL(functor::ReLU, half); -ELEMENTWISE_KERNEL_IMPL(functor::ReLU, BFloat16); diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu.cc b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu.cc deleted file mode 100644 index fdb62d3a2aec5..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu.cc +++ /dev/null @@ -1,75 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
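// --- Editor's reference sketch (not part of this file) ---
// Semantics of the fused op implemented below: Y = FastGelu(X * W + bias), with the
// GEMM restricted to alpha == 1 and beta == 0 as enforced in ComputeInternal.
#include <cmath>
#include <vector>

static float FastGeluRef(float x) {
  const float kAlpha = 0.7978845608028654f;  // sqrt(2/pi)
  return 0.5f * x * (1.0f + std::tanh(kAlpha * (x + 0.044715f * x * x * x)));
}

// naive M x K times K x N reference; bias (when present) is broadcast over rows
static void GemmFastGeluRef(int M, int N, int K, const std::vector<float>& X,
                            const std::vector<float>& W, const std::vector<float>& bias,
                            std::vector<float>& Y) {
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = bias.empty() ? 0.0f : bias[n];
      for (int k = 0; k < K; ++k) acc += X[m * K + k] * W[k * N + n];
      Y[m * N + n] = FastGeluRef(acc);
    }
  }
}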
- -#include "contrib_ops/rocm/bert/gemm_fast_gelu.h" - -#include "contrib_ops/rocm/bert/gemm_fast_gelu_common.h" -#include "contrib_ops/rocm/bert/gemm_fast_gelu_impl.h" -#include "core/providers/cpu/math/matmul_helper.h" -#include "core/providers/rocm/rocm_common.h" - -using onnxruntime::rocm::ToHipType; - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -#define REGISTER_KERNEL_TYPED(T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - GemmFastGelu, \ - kMSDomain, \ - 1, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - GemmFastGelu); - -REGISTER_KERNEL_TYPED(float) -REGISTER_KERNEL_TYPED(MLFloat16) -REGISTER_KERNEL_TYPED(BFloat16) - -template -Status GemmFastGelu::ComputeInternal(OpKernelContext* ctx) const { - typedef typename ToHipType::MappedType HipT; - - const auto* X = ctx->Input(0); - const auto* W = ctx->Input(1); - const auto* bias = ctx->Input(2); - - bool transa = false; - bool transb = false; - bool trans_batch_a = false; - bool trans_batch_b = false; - - MatMulComputeHelper helper; - ORT_RETURN_IF_ERROR(helper.Compute(X->Shape(), W->Shape(), transa, transb, trans_batch_a, trans_batch_b, false)); - - Tensor* Y = ctx->Output(0, helper.OutputShape()); - - // Bail out early if the output is going to be empty - if (Y->Shape().Size() == 0) - return Status::OK(); - - // gemmfastgelu only support alpha == 1 and beta == 0 - const HipT alpha = ToHipType::FromFloat(1.0f); - const HipT beta = ToHipType::FromFloat(0.0f); - - using onnxruntime::rocm::tunable::blas::BlasOp; - - return blas::row_major::GemmFastGelu( - GetTuningContext(), ctx->GetComputeStream(), GetHipblasHandle(ctx), - transa ? BlasOp::Trans : BlasOp::NonTrans, - transb ? BlasOp::Trans : BlasOp::NonTrans, - helper.M(), helper.N(), helper.K(), - alpha, - reinterpret_cast(X->Data()), helper.Lda(transa), - reinterpret_cast(W->Data()), helper.Ldb(transb), - (nullptr != bias) ? reinterpret_cast(bias->Data()) : nullptr, - beta, - reinterpret_cast(Y->MutableData()), helper.Ldc()); -} - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu.h b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu.h deleted file mode 100644 index ae4f84fa5f033..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu.h +++ /dev/null @@ -1,23 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include "core/common/common.h" -#include "core/providers/rocm/rocm_kernel.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -using onnxruntime::rocm::RocmKernel; - -template -class GemmFastGelu final : public RocmKernel { - public: - GemmFastGelu(const OpKernelInfo& op_kernel_info) : RocmKernel(op_kernel_info) {} - Status ComputeInternal(OpKernelContext* ctx) const override; -}; - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_ck.cuh b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_ck.cuh deleted file mode 100644 index 77f53f9eed027..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_ck.cuh +++ /dev/null @@ -1,133 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#pragma once - -#include -#include -#include - -#ifdef USE_COMPOSABLE_KERNEL -#include "core/providers/rocm/composable_kernel_common.h" - -#include "ck/ck.hpp" -#include "ck/library/tensor_operation_instance/gpu/gemm_add_fastgelu.hpp" -#include "ck/library/tensor_operation_instance/gpu/gemm_fastgelu.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#endif - -#include "contrib_ops/rocm/bert/gemm_fast_gelu_common.h" - -using onnxruntime::rocm::ToHipType; - -namespace onnxruntime { -namespace contrib { -namespace rocm { -namespace blas { -namespace internal { - -#ifdef USE_COMPOSABLE_KERNEL - -using onnxruntime::rocm::CKBlasOpAdaptor; -using onnxruntime::rocm::CKDataTypeAdaptor; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -using Nop = ck::tensor_operation::element_wise::PassThrough; -using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu; -using FastGelu = ck::tensor_operation::element_wise::FastGelu; - -template -auto GetCKGemmAddFastGeluTypeStringAndOps() { - using CKDataType = typename CKDataTypeAdaptor::type; - using ALayout = typename CKBlasOpAdaptor::type; - using BLayout = typename CKBlasOpAdaptor::type; - using DeviceGemmAddFastGelu = ck::tensor_operation::device::DeviceGemmMultipleD< - ALayout, BLayout, ck::Tuple, Row, - CKDataType, CKDataType, ck::Tuple, CKDataType, - Nop, Nop, AddFastGelu>; - using InstanceFactory = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory; - - std::vector>>> ret; - for (auto&& impl : InstanceFactory::GetInstances()) { - auto type_string = onnxruntime::MakeString("withbias ", impl->GetTypeString()); - - auto invoker = impl->MakeInvokerPointer(); - auto ck_gemmfastgelu_op = [impl = std::move(impl), invoker = std::move(invoker)](const GemmFastGeluParams* params) -> Status { - auto one = ToHipType::FromFloat(1.0f); - auto zero = ToHipType::FromFloat(0.0f); - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->alpha != one || params->beta != zero || params->bias == nullptr, - impl->GetTypeString(), " only supports alpha == 1 and beta == 0 and bias != nullptr"); - - auto nop = Nop{}; - auto addfastgelu = AddFastGelu{}; - auto arg = impl->MakeArgumentPointer(params->a, params->b, std::array{params->bias}, params->c, - params->m, params->n, params->k, - params->lda, params->ldb, std::array{0}, params->ldc, - nop, nop, addfastgelu); - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), - impl->GetTypeString(), " does not support the params"); - invoker->Run(arg.get(), StreamConfig{params->StreamHandle()}); - return Status::OK(); - }; - ret.emplace_back(std::make_pair(std::move(type_string), std::move(ck_gemmfastgelu_op))); - } - return ret; -} - -template -auto GetCKGemmFastGeluTypeStringAndOps() { - using CKDataType = typename CKDataTypeAdaptor::type; - using ALayout = typename CKBlasOpAdaptor::type; - using BLayout = typename CKBlasOpAdaptor::type; - using DeviceGemmFastGelu = ck::tensor_operation::device::DeviceGemmMultipleD< - ALayout, BLayout, ck::Tuple<>, Row, - CKDataType, CKDataType, ck::Tuple<>, CKDataType, - Nop, Nop, FastGelu>; - using InstanceFactory = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory; - - std::vector>>> ret; - for (auto&& impl : InstanceFactory::GetInstances()) { - auto type_string = onnxruntime::MakeString("nobias ", 
impl->GetTypeString()); - auto invoker = impl->MakeInvokerPointer(); - auto ck_gemmfastgelu_op = [impl = std::move(impl), invoker = std::move(invoker)](const GemmFastGeluParams<T>* params) -> Status { - auto one = ToHipType<T>::FromFloat(1.0f); - auto zero = ToHipType<T>::FromFloat(0.0f); - - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->alpha != one || params->beta != zero || params->bias != nullptr, - impl->GetTypeString(), " only supports alpha == 1 and beta == 0 and bias == nullptr"); - - auto nop = Nop{}; - auto fastgelu = FastGelu{}; - auto arg = impl->MakeArgumentPointer(params->a, params->b, - {}, - params->c, - params->m, params->n, params->k, - params->lda, params->ldb, - {}, - params->ldc, - nop, nop, fastgelu); - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), - impl->GetTypeString(), " does not support the params"); - invoker->Run(arg.get(), StreamConfig{params->StreamHandle()}); - return Status::OK(); - }; - ret.emplace_back(std::make_pair(std::move(type_string), std::move(ck_gemmfastgelu_op))); - } - return ret; -} -#else -struct Row {}; -struct Col {}; -#endif // USE_COMPOSABLE_KERNEL - -} // namespace internal -} // namespace blas -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_common.h b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_common.h deleted file mode 100644 index 2b8a21b83f177..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_common.h +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include - -#include "core/providers/rocm/rocm_common.h" -#include "core/providers/rocm/tunable/gemm_common.h" -#include "core/providers/rocm/tunable/rocm_tunable.h" - -using onnxruntime::rocm::tunable::blas::BlasOp; -using onnxruntime::rocm::tunable::blas::BlasOpToString; - -namespace onnxruntime { -namespace contrib { -namespace rocm { -namespace blas { - -template <typename T> -struct GemmFastGeluParams : OpParams { - std::string Signature() const override { - bool has_bias = (nullptr != bias); - return MakeString(BlasOpToString(opa), BlasOpToString(opb), "_", m, "_", n, "_", k, '_', has_bias); - } - hipblasHandle_t handle; - BlasOp opa; - BlasOp opb; - int64_t m; - int64_t n; - int64_t k; - T alpha; - const T* a; - int64_t lda; - const T* b; - int64_t ldb; - const T* bias; - T beta; - T* c; - int64_t ldc; -}; - -} // namespace blas -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime
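The Signature() string above (transpose flags, then M, N, K, then a bias flag, e.g. "NN_1024_768_768_1") is how the tunable-op layer identifies a problem shape, so the fastest candidate found for one shape can be reused on the next call with the same shape. A toy sketch of that keying scheme, with hypothetical names that are not ORT APIs:

#include <cstdint>
#include <map>
#include <sstream>
#include <string>

// Builds a shape-signature string analogous to GemmFastGeluParams::Signature().
std::string MakeSignature(char opa, char opb, std::int64_t m, std::int64_t n,
                          std::int64_t k, bool has_bias) {
  std::ostringstream oss;
  oss << opa << opb << '_' << m << '_' << n << '_' << k << '_' << (has_bias ? 1 : 0);
  return oss.str();
}

// A tuner can then memoize the winning candidate's index per signature.
std::map<std::string, int> best_op_id;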
diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.cu b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.cu deleted file mode 100644 index 8d7e64b1015be..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.cu +++ /dev/null @@ -1,91 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#define _GEMM_FASTGELU_H_KEEP_SIGNATURE_DEFINES -#include "contrib_ops/rocm/bert/gemm_fast_gelu_impl.h" - -#include -#include - -#include "contrib_ops/rocm/bert/gemm_fast_gelu_tunable.cuh" -#include "core/providers/rocm/shared_inc/fpgeneric.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { -namespace blas { - -namespace row_major { - -template <typename T, typename ScalarT> -inline GEMMFASTGELU(T, ScalarT) { - GemmFastGeluParams<T> params; - params.tuning_ctx = tuning_ctx; - params.stream = stream; - params.handle = handle; - - params.opa = opa; - params.opb = opb; - params.m = m; - params.n = n; - params.k = k; - if constexpr (!std::is_same_v<T, ScalarT> && std::is_same_v<ScalarT, float>) { - params.alpha = ToHipType<T>::FromFloat(std::forward<ScalarT>(alpha)); - } else { - params.alpha = alpha; - } - params.a = a; - params.lda = lda; - params.b = b; - params.ldb = ldb; - params.bias = bias; - if constexpr (!std::is_same_v<T, ScalarT> && std::is_same_v<ScalarT, float>) { - params.beta = ToHipType<T>::FromFloat(std::forward<ScalarT>(beta)); - } else { - params.beta = beta; - } - params.c = c; - params.ldc = ldc; - - if (tuning_ctx->IsTunableOpEnabled()) { - if (opa == BlasOp::N && opb == BlasOp::N) { - static internal::GemmFastGeluTunableOp<T, BlasOp::N, BlasOp::N> gemm_fast_gelu{}; - return gemm_fast_gelu(&params); - } else if (opa == BlasOp::T && opb == BlasOp::N) { - static internal::GemmFastGeluTunableOp<T, BlasOp::T, BlasOp::N> gemm_fast_gelu{}; - return gemm_fast_gelu(&params); - } else if (opa == BlasOp::N && opb == BlasOp::T) { - static internal::GemmFastGeluTunableOp<T, BlasOp::N, BlasOp::T> gemm_fast_gelu{}; - return gemm_fast_gelu(&params); - } else /*if (opa == BlasOp::T && opb == BlasOp::T)*/ { - static internal::GemmFastGeluTunableOp<T, BlasOp::T, BlasOp::T> gemm_fast_gelu{}; - return gemm_fast_gelu(&params); - } - } - - return internal::GemmFastGeluUnfused(&params); -} - -#define CALL_GEMMFASTGELU(T, ScalarT) \ - GemmFastGelu<T, ScalarT>(tuning_ctx, stream, handle, \ - opa, opb, \ - m, n, k, \ - alpha, a, lda, b, ldb, bias, \ - beta, c, ldc) - -// clang-format off -GEMMFASTGELU(float, float ) { return CALL_GEMMFASTGELU(float, float ); } -GEMMFASTGELU(half, half ) { return CALL_GEMMFASTGELU(half, half ); } -GEMMFASTGELU(BFloat16, BFloat16) { return CALL_GEMMFASTGELU(BFloat16, BFloat16); } -GEMMFASTGELU(half, float ) { return CALL_GEMMFASTGELU(half, float ); } -GEMMFASTGELU(BFloat16, float ) { return CALL_GEMMFASTGELU(BFloat16, float ); } -// clang-format on - -#undef CALL_GEMMFASTGELU - -} // namespace row_major - -} // namespace blas -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime
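The GEMMFASTGELU macro used above expands to the full function signature, so the declarations in gemm_fast_gelu_impl.h and the definitions in the .cu file cannot drift apart. The same idiom, reduced to a runnable toy example with hypothetical names:

#include <iostream>

// The macro is the single source of truth for the signature.
#define SAXPY_SIG(T) T Saxpy(T a, T x, T y)

SAXPY_SIG(float);   // declaration, as a header would have it
SAXPY_SIG(double);  // one line per supported type

SAXPY_SIG(float) { return a * x + y; }   // definitions reuse the same macro,
SAXPY_SIG(double) { return a * x + y; }  // so parameter lists stay in sync

int main() { std::cout << Saxpy(2.0f, 3.0f, 1.0f) << "\n"; }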
diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.h b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.h deleted file mode 100644 index b707c63ef44be..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.h +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "contrib_ops/rocm/bert/gemm_fast_gelu_common.h" -#include "core/common/status.h" -#include "core/common/float16.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { -namespace blas { - -#define GEMMFASTGELU(T, ScalarT) \ - common::Status GemmFastGelu( \ - RocmTuningContext* tuning_ctx, Stream* stream, hipblasHandle_t handle, \ - BlasOp opa, BlasOp opb, \ - std::int64_t m, std::int64_t n, std::int64_t k, \ - ScalarT alpha, const T* a, std::int64_t lda, const T* b, std::int64_t ldb, \ - const T* bias, ScalarT beta, T* c, std::int64_t ldc) - -namespace row_major { - -GEMMFASTGELU(float, float); -GEMMFASTGELU(half, half); -GEMMFASTGELU(BFloat16, BFloat16); -GEMMFASTGELU(half, float); -GEMMFASTGELU(BFloat16, float); - -} // namespace row_major - -} // namespace blas -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime - -#ifndef _GEMM_FASTGELU_H_KEEP_SIGNATURE_DEFINES -#undef GEMMFASTGELU -#endif diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_tunable.cuh b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_tunable.cuh deleted file mode 100644 index e157aa57f8c43..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_tunable.cuh +++ /dev/null @@ -1,83 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include - -#include "contrib_ops/rocm/bert/elementwise.h" -#include "contrib_ops/rocm/bert/gemm_fast_gelu_ck.cuh" -#include "contrib_ops/rocm/bert/gemm_fast_gelu_common.h" -#include "core/providers/rocm/tunable/gemm.h" -#include "core/providers/rocm/tunable/gemm_hipblaslt.h" -#include "core/providers/rocm/tunable/rocm_tunable.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { -namespace blas { -namespace internal { - -using namespace onnxruntime::rocm::tunable::blas::internal; - -template <typename T> -Status GemmFastGeluUnfused(const GemmFastGeluParams<T>* params) { - namespace column_major = onnxruntime::rocm::tunable::blas::column_major; - ORT_RETURN_IF_ERROR(column_major::Gemm(params->tuning_ctx, params->stream, params->handle, - params->opb, params->opa, - params->n, params->m, params->k, - params->alpha, params->b, params->ldb, params->a, params->lda, - params->beta, params->c, params->ldc)); - - int64_t fast_gelu_input_length = params->m * params->n; - int64_t bias_length = (params->bias != nullptr) ? params->n : 0; - - // Because GemmFastGeluUnfused is a combination of GemmOp and FastGeluOp, the FastGeluOp in this combination is - // an in-place computation. - // 1. If we call GemmFastGeluUnfused directly with tuning enabled, the input buffer of FastGelu may be updated - //    repeatedly and finally produce an incorrect result. This only happens if the tuning's FindFastest is invoked. - // 2. It is safe to call GemmFastGeluUnfused with tuning disabled: FastGelu runs only once and produces the correct result. - // 3. It is safe to call GemmFastGeluUnfused as part of GemmFastGeluTunableOp with tuning enabled: GemmTunableOp and - //    FastGeluTunableOp each tune in their first warmup step during the GemmFastGeluUnfused profiling process, and - //    after that, calls to GemmFastGeluUnfused do not invoke FastGelu's FindFastest again. - // - // Note: if any change leads to direct usage of GemmFastGeluUnfused, add PreTuning() and PostTuning() in FastGeluTunableOp - // to protect the original input values. - return onnxruntime::contrib::rocm::LaunchElementwiseKernel<functor::FastGeLU, T>( - params->tuning_ctx, params->Stream(), - params->c, static_cast<int>(fast_gelu_input_length), - params->bias, static_cast<int>(bias_length), - params->c); -}
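The hazard in the comment above can be guarded exactly as the note suggests: snapshot the in-place buffer before repeated timing runs and restore it between them. A host-side sketch of that PreTuning()/PostTuning() idea; this is a hypothetical analogue for illustration, not the ORT API:

#include <cstddef>
#include <vector>

// Times an in-place candidate op without letting one run's output become the
// next run's input: save the buffer first, restore it after every invocation.
template <typename T, typename Op>
void TimeInPlaceCandidate(Op op, std::vector<T>& buffer, int runs) {
  const std::vector<T> snapshot = buffer;  // PreTuning: keep the original input
  for (int i = 0; i < runs; ++i) {
    op(buffer.data(), buffer.size());      // in-place candidate under test
    buffer = snapshot;                     // PostTuning: restore before the next run
  }
}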
- -template <typename T, BlasOp OpA, BlasOp OpB> -class GemmFastGeluTunableOp : public TunableOp<GemmFastGeluParams<T>> { - public: - GemmFastGeluTunableOp() { - this->RegisterOp(GemmFastGeluUnfused<T>); -#ifdef USE_COMPOSABLE_KERNEL - for (auto&& [_, op] : GetCKGemmAddFastGeluTypeStringAndOps<T, OpA, OpB>()) { - ORT_UNUSED_PARAMETER(_); - this->RegisterOp(std::move(op)); - } - for (auto&& [_, op] : GetCKGemmFastGeluTypeStringAndOps<T, OpA, OpB>()) { - ORT_UNUSED_PARAMETER(_); - this->RegisterOp(std::move(op)); - } -#endif - -#ifdef USE_HIPBLASLT - for (auto&& [_, op] : GetHipBlasLtGemmFastGeluTypeStringAndOps<T, OpA, OpB>()) { - ORT_UNUSED_PARAMETER(_); - this->RegisterOp(std::move(op)); - } -#endif - } -}; - -} // namespace internal -} // namespace blas -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu deleted file mode 100644 index 09a6550549614..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu +++ /dev/null @@ -1,530 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/providers/shared_library/provider_api.h" -#include "core/providers/rocm/rocm_common.h" -#include "core/platform/env_var_utils.h" -#include "contrib_ops/rocm/bert/group_query_attention.h" -#include "contrib_ops/cpu/bert/group_query_attention_helper.h" -#include "contrib_ops/rocm/bert/rotary_embedding_impl.h" -#include "contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh" - -#ifdef USE_COMPOSABLE_KERNEL_CK_TILE -#include "ck_tile/core/numeric/integer.hpp" -#include "fmha_fwd.hpp" -#endif - -using namespace onnxruntime::rocm; -using namespace ::onnxruntime::common; -using namespace ONNX_NAMESPACE; - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -#define REGISTER_KERNEL_TYPED(T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - GroupQueryAttention, \ - kMSDomain, \ - 1, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType<T>()) \ - .TypeConstraint("M", DataTypeImpl::GetTensorType<int32_t>()) \ - .MayInplace(3, 1) \ - .MayInplace(4, 2) \ - .InputMemoryType(OrtMemTypeCPUInput, 6), \ - GroupQueryAttention<T>); - -// REGISTER_KERNEL_TYPED(float) -REGISTER_KERNEL_TYPED(MLFloat16) -// REGISTER_KERNEL_TYPED(BFloat16) - -template <typename T> -std::string GetCkFmhaDataTypeString(); - -template <> -std::string GetCkFmhaDataTypeString<MLFloat16>() { - return "fp16"; -} - -template <> -std::string GetCkFmhaDataTypeString<BFloat16>() { - return "bf16"; -} - -__global__ void seqlens_inc_kernel(const int* seqlens, int* out, int num_elems, int inc) { - int idx = blockDim.x * blockIdx.x + threadIdx.x; - if (idx < num_elems) { - out[idx] = seqlens[idx] + inc; - } -} - -Status LaunchSeqlensInc(hipStream_t stream, const int* seqlens, int* out, int num_elems, int inc) { - constexpr int NumThreads = 128; - int num_blks = CeilDiv(num_elems, NumThreads); - seqlens_inc_kernel<<<num_blks, NumThreads, 0, stream>>>(seqlens, out, num_elems, inc); - return HIP_CALL(hipGetLastError()); -} - -__global__ void seqstart_init_kernel(int* out, int num_elems, int length_per_seq) { - int idx = blockDim.x * blockIdx.x + threadIdx.x; - if (idx < num_elems) { - out[idx] = idx * length_per_seq; - } - if (idx == 0) { - out[num_elems] = num_elems * length_per_seq; - } -} - -Status
LaunchSeqStartInit(hipStream_t stream, int* out, int num_elems, int length_per_seq) { - constexpr int NumThreads = 128; - int num_blks = CeilDiv(num_elems, NumThreads); - seqstart_init_kernel<<<num_blks, NumThreads, 0, stream>>>(out, num_elems, length_per_seq); - return HIP_CALL(hipGetLastError()); -} - -// Kernel to convert seqlens_k to position_ids -__global__ void SeqlensToPosIdsPrompt(const int32_t* seqlens_k, int64_t* position_ids, const int seqlen, - const int batch_size) { - int tid = blockDim.x * blockIdx.x + threadIdx.x; - int b = tid / seqlen; - int s = tid % seqlen; - if (b < batch_size) { - if (s < seqlens_k[b] + 1) { - position_ids[tid] = s; - } else { - position_ids[tid] = 1; - } - } -} - -// Kernel to convert seqlens_k to position_ids -__global__ void SeqlensToPosIdsToken(const int32_t* seqlens_k, int64_t* position_ids, const int batch_size) { - int tid = blockDim.x * blockIdx.x + threadIdx.x; - if (tid < batch_size) { - position_ids[tid] = seqlens_k[tid]; - } -} - -// Convert seqlens_k to position_ids -Status LaunchSeqlensToPosIds(contrib::GroupQueryAttentionParameters& parameters, const int32_t* seqlens_k, - int64_t* position_ids, hipStream_t stream, const int max_threads_per_block) { - const int seqlen = parameters.sequence_length; - const int batch_size = parameters.batch_size; - const int threads = max_threads_per_block; - const int blocks = (batch_size * seqlen + threads - 1) / threads; - if (parameters.is_first_prompt) { - SeqlensToPosIdsPrompt<<<blocks, threads, 0, stream>>>(seqlens_k, position_ids, seqlen, batch_size); - } else { - SeqlensToPosIdsToken<<<blocks, threads, 0, stream>>>(seqlens_k, position_ids, batch_size); - } - return HIP_CALL(hipGetLastError()); -} - -template <typename T> -GroupQueryAttention<T>::GroupQueryAttention(const OpKernelInfo& info) - : RocmKernel(info) { - int64_t num_heads = 0; - int64_t kv_num_heads = 0; - ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0); - ORT_ENFORCE(info.GetAttr("kv_num_heads", &kv_num_heads).IsOK() && kv_num_heads > 0 && num_heads % kv_num_heads == 0); - num_heads_ = static_cast<int>(num_heads); - kv_num_heads_ = static_cast<int>(kv_num_heads); - is_past_bsnh_ = false; - is_unidirectional_ = true; - local_window_size_ = static_cast<int>(info.GetAttrOrDefault<int64_t>("local_window_size", -1)); - do_rotary_ = info.GetAttrOrDefault<int64_t>("do_rotary", 0) == 1; - rotary_interleaved_ = info.GetAttrOrDefault<int64_t>("rotary_interleaved", 0) == 1; - scale_ = info.GetAttrOrDefault<float>("scale", 0.0f); -} - -template <> -std::once_flag GroupQueryAttention<MLFloat16>::arch_checking_{}; - -template <> -std::once_flag GroupQueryAttention<BFloat16>::arch_checking_{}; - -template <typename T> -Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* ctx) const { -#if USE_COMPOSABLE_KERNEL_CK_TILE - auto hip_stream = static_cast<hipStream_t>(ctx->GetComputeStream()->GetHandle()); - const Tensor* query = ctx->Input<Tensor>(0); - const Tensor* key = ctx->Input<Tensor>(1); - const Tensor* value = ctx->Input<Tensor>(2); - const Tensor* past_key = ctx->Input<Tensor>(3); - const Tensor* past_value = ctx->Input<Tensor>(4); - const Tensor* seqlens_k = ctx->Input<Tensor>(5); - const Tensor* total_seqlen = ctx->Input<Tensor>(6); - const Tensor* cos_cache = ctx->Input<Tensor>(7); - const Tensor* sin_cache = ctx->Input<Tensor>(8); - - auto& device_prop = GetDeviceProp(); - std::call_once( - arch_checking_, - [](const hipDeviceProp_t& device_prop) { - if (std::string_view(device_prop.gcnArchName).find("gfx90a") == std::string_view::npos && - std::string_view(device_prop.gcnArchName).find("gfx942") == std::string_view::npos) { - LOGS_DEFAULT(WARNING) - << "GroupQueryAttention currently only supports the ck_tile fmha backend, which only supports " - << "CDNA2 and CDNA3 archs."; -
LOGS_DEFAULT(WARNING) - << "GroupQueryAttention running on an unsupported GPU may result in " - << "hipErrorNoBinaryForGpu or hipErrorSharedObjectInitFailed errors."; - } - }, - device_prop); - - GroupQueryAttentionParameters parameters; - using HipT = typename ToHipType<T>::MappedType; - - const int max_thr_per_blk = device_prop.maxThreadsPerBlock; - - ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckInputs(query, - key, - value, - past_key, - past_value, - cos_cache, - sin_cache, - &parameters, - num_heads_, - kv_num_heads_, - seqlens_k, - total_seqlen, - is_past_bsnh_, - scale_, - max_thr_per_blk)); - - const int batch_size = parameters.batch_size; - const int sequence_length = parameters.sequence_length; - const int kv_sequence_length = parameters.sequence_length; - const int num_heads = parameters.num_heads; - const int kv_num_heads = parameters.kv_num_heads; - const int head_size = parameters.head_size; - AttentionQkvFormat past_kv_format = parameters.past_kv_format; - - parameters.local_window_size = local_window_size_; - parameters.is_unidirectional = is_unidirectional_; - // parameters.zeros_count = kZerosCount; - // parameters.zero_ptr = zeros_.get(); - // parameters.left_padding = left_padding_; - parameters.do_rotary = do_rotary_; - parameters.rotary_interleaved = rotary_interleaved_; - - ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckNoQKOutput( - ctx->OutputCount(), - static_cast<int>(Info().GetAttrOrDefault<int64_t>("qk_output", static_cast<int64_t>(QKOutputType::NO_OUTPUT))))); - - if (do_rotary_ && (cos_cache == nullptr || sin_cache == nullptr)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "cos_cache and sin_cache must be passed to GroupQueryAttention when do_rotary = 1"); - } - - TensorShapeVector output_shape(3); - output_shape[0] = static_cast<int64_t>(batch_size); - output_shape[1] = static_cast<int64_t>(sequence_length); - output_shape[2] = static_cast<int64_t>(parameters.hidden_size); - Tensor* output = ctx->Output(0, output_shape); - Strides output_strides = Strides::BSNHMemory(batch_size, sequence_length, num_heads, head_size); - - int4 past_shape; - std::vector<int64_t> present_dims; - Strides present_strides; - Strides past_strides; - if (past_kv_format == AttentionQkvFormat::Q_K_V_BSNH) { - past_shape = { - batch_size, parameters.seqlen_past_kv_cache, kv_num_heads, head_size}; - past_strides = Strides::BSNHMemory( - batch_size, parameters.seqlen_past_kv_cache, kv_num_heads, head_size); - present_dims = { - batch_size, parameters.seqlen_present_kv_cache, kv_num_heads, head_size}; - present_strides = Strides::BSNHMemory( - batch_size, parameters.seqlen_present_kv_cache, kv_num_heads, head_size); - } else { // BNSH - past_shape = { - batch_size, kv_num_heads, parameters.seqlen_past_kv_cache, head_size}; - past_strides = Strides::BNSHMemory( - batch_size, kv_num_heads, parameters.seqlen_past_kv_cache, head_size); - present_dims = { - batch_size, kv_num_heads, parameters.seqlen_present_kv_cache, head_size}; - present_strides = Strides::BNSHMemory( - batch_size, kv_num_heads, parameters.seqlen_present_kv_cache, head_size); - } - TensorShape present_shape(present_dims); - Tensor* present_key = ctx->Output(1, present_shape); - Tensor* present_value = ctx->Output(2, present_shape); - - Strides query_strides; - Strides key_strides; - Strides value_strides; - int4 kv_shape{batch_size, kv_num_heads, kv_sequence_length, head_size}; // BNSH coord - const HipT* query_ptr = reinterpret_cast<const HipT*>(query->DataRaw()); - const HipT* key_ptr; - const HipT* value_ptr; - if (!parameters.is_packed_qkv) { - query_strides
= Strides::BSNHMemory(batch_size, sequence_length, num_heads, head_size); - key_strides = Strides::BSNHMemory(batch_size, kv_sequence_length, kv_num_heads, head_size); - value_strides = key_strides; - key_ptr = reinterpret_cast<const HipT*>(key->DataRaw()); - value_ptr = reinterpret_cast<const HipT*>(value->DataRaw()); - } else { - query_strides = Strides::BSNHMemory(batch_size, sequence_length, num_heads + 2 * kv_num_heads, head_size); - key_strides = Strides::BSNHMemory(batch_size, sequence_length, num_heads + 2 * kv_num_heads, head_size); - value_strides = query_strides; - const size_t key_offset = static_cast<size_t>(num_heads * head_size); - const size_t value_offset = static_cast<size_t>(kv_num_heads * head_size); - key_ptr = query_ptr + key_offset; - value_ptr = key_ptr + value_offset; - } - - IAllocatorUniquePtr<HipT> rotary_q_tmp; - IAllocatorUniquePtr<HipT> rotary_k_tmp; - if (parameters.do_rotary) { - size_t q_size = static_cast<size_t>(batch_size * sequence_length * num_heads * head_size); - size_t k_size = static_cast<size_t>(batch_size * sequence_length * kv_num_heads * head_size); - auto rotary_q_strides = Strides::BSNHMemory(batch_size, sequence_length, num_heads, head_size); - auto rotary_k_strides = Strides::BSNHMemory(batch_size, sequence_length, kv_num_heads, head_size); - - rotary_q_tmp = GetScratchBuffer<HipT>(q_size, ctx->GetComputeStream()); - rotary_k_tmp = GetScratchBuffer<HipT>(k_size, ctx->GetComputeStream()); - auto rotary_position_ids_tmp = GetScratchBuffer<int64_t>(sequence_length * batch_size, ctx->GetComputeStream()); - ORT_RETURN_IF_ERROR(LaunchSeqlensToPosIds(parameters, - reinterpret_cast<const int32_t*>(seqlens_k->DataRaw()), - reinterpret_cast<int64_t*>(rotary_position_ids_tmp.get()), - hip_stream, max_thr_per_blk)); - // Launch rotary embedding kernel - ORT_RETURN_IF_ERROR(LaunchRotaryEmbeddingKernel(hip_stream, rotary_q_tmp.get(), query_ptr, - reinterpret_cast<const int64_t*>(rotary_position_ids_tmp.get()), - reinterpret_cast<const HipT*>(cos_cache->DataRaw()), - reinterpret_cast<const HipT*>(sin_cache->DataRaw()), - parameters.batch_size, parameters.sequence_length, - parameters.num_heads, parameters.head_size, - parameters.rotary_dim, parameters.seqlen_present_kv_cache, - /*position_ids_format*/ 1, parameters.rotary_interleaved, - max_thr_per_blk, - query_strides.ForBNSHCoord(), - rotary_q_strides.ForBNSHCoord())); - ORT_RETURN_IF_ERROR(LaunchRotaryEmbeddingKernel(hip_stream, rotary_k_tmp.get(), key_ptr, - reinterpret_cast<const int64_t*>(rotary_position_ids_tmp.get()), - reinterpret_cast<const HipT*>(cos_cache->DataRaw()), - reinterpret_cast<const HipT*>(sin_cache->DataRaw()), - parameters.batch_size, parameters.sequence_length, - parameters.kv_num_heads, parameters.head_size, - parameters.rotary_dim, parameters.seqlen_present_kv_cache, - /*position_ids_format*/ 1, parameters.rotary_interleaved, - max_thr_per_blk, - key_strides.ForBNSHCoord(), - rotary_k_strides.ForBNSHCoord())); - query_ptr = reinterpret_cast<const HipT*>(rotary_q_tmp.get()); - key_ptr = reinterpret_cast<const HipT*>(rotary_k_tmp.get()); - query_strides = rotary_q_strides; - key_strides = rotary_k_strides; - } - - const int* seqlens_k_ptr = seqlens_k ?
reinterpret_cast<const int*>(seqlens_k->DataRaw()) : nullptr; - IAllocatorUniquePtr<int> seqlens_k_tmp; - - // build present kv cache - auto* present_key_ptr = reinterpret_cast<HipT*>(present_key->MutableDataRaw()); - auto* present_value_ptr = reinterpret_cast<HipT*>(present_value->MutableDataRaw()); - if (parameters.is_first_prompt) { - // copy prompt kv to present kv - ORT_RETURN_IF_ERROR(LaunchStridedCopy(hip_stream, key_ptr, kv_shape, key_strides.ForBNSHCoord(), - present_key_ptr, present_strides.ForBNSHCoord(), max_thr_per_blk)); - ORT_RETURN_IF_ERROR(LaunchStridedCopy(hip_stream, value_ptr, kv_shape, value_strides.ForBNSHCoord(), - present_value_ptr, present_strides.ForBNSHCoord(), max_thr_per_blk)); - } else { - const auto* past_key_ptr = past_key == nullptr ? nullptr : reinterpret_cast<const HipT*>(past_key->DataRaw()); - const auto* past_value_ptr = past_key == nullptr ? nullptr : reinterpret_cast<const HipT*>(past_value->DataRaw()); - parameters.kv_share_buffer = past_key_ptr == present_key_ptr; // FIXME: - if (!parameters.kv_share_buffer) { - // copy past to present, - // NOTE: we do a low perf full buffer copy because seqlens_k indicates that the seqlens of different seqs are - // not the same, aka, this can not be as simple as a strided copy - ORT_RETURN_IF_ERROR(LaunchStridedCopy(hip_stream, past_key_ptr, past_shape, past_strides.ForBNSHCoord(), - present_key_ptr, present_strides.ForBNSHCoord(), max_thr_per_blk)); - ORT_RETURN_IF_ERROR(LaunchStridedCopy(hip_stream, past_value_ptr, past_shape, past_strides.ForBNSHCoord(), - present_value_ptr, present_strides.ForBNSHCoord(), max_thr_per_blk)); - } else { - // In the case of share buffer - ORT_ENFORCE(past_key_ptr == nullptr || past_key_ptr == present_key_ptr); - ORT_ENFORCE(past_key_ptr == nullptr || past_value_ptr == present_value_ptr); - } - // then append new kv to present - size_t buffer_offset = seqlens_k ? 0 : present_strides.OffsetAt(0, 0, kv_sequence_length, 0); - ORT_RETURN_IF_ERROR(LaunchStridedCopy( - hip_stream, key_ptr, kv_shape, key_strides.ForBNSHCoord(), /*in_seqlens_offset=*/nullptr, - present_key_ptr + buffer_offset, present_strides.ForBNSHCoord(), seqlens_k_ptr, - max_thr_per_blk)); - ORT_RETURN_IF_ERROR(LaunchStridedCopy( - hip_stream, value_ptr, kv_shape, value_strides.ForBNSHCoord(), /*in_seqlens_offset=*/nullptr, - present_value_ptr + buffer_offset, present_strides.ForBNSHCoord(), seqlens_k_ptr, - max_thr_per_blk)); - - // NOTE: ORT: seqlens_k indicates past sequence lengths for the token generation case. - // we should call fmha with total sequence lengths - seqlens_k_tmp = GetScratchBuffer<int>(batch_size * sizeof(int), ctx->GetComputeStream()); - ORT_RETURN_IF_ERROR(LaunchSeqlensInc(hip_stream, seqlens_k_ptr, seqlens_k_tmp.get(), batch_size, sequence_length)); - seqlens_k_ptr = seqlens_k_tmp.get(); - } - static_assert(std::is_same_v<ck_tile::index_t, int32_t>); - - const float scale = parameters.scale == 0.0f - ? 1.f / sqrt(static_cast<float>(parameters.head_size)) - : parameters.scale; - bias_enum bias_type = bias_enum::no_bias; - - mask_info mask = [&]() { - if (local_window_size_ != -1) { - mask_info ret; - ret.type = mask_enum::window_generic; - ret.left = local_window_size_; - ret.right = parameters.is_unidirectional ?
0 : -1; - // ret.x = kv_sequence_length - (sequence_length - ret.left); - // ret.y = sequence_length + (ret.right - kv_sequence_length); - return ret; - } - - if (parameters.is_first_prompt && is_unidirectional_) { - return mask_info::decode("t", sequence_length, kv_sequence_length); - } - - return mask_info::decode("0", sequence_length, kv_sequence_length); - }(); - - auto seqstart_q_tmp = GetScratchBuffer<int>((batch_size + 1) * sizeof(int), ctx->GetComputeStream()); - auto seqstart_k_tmp = GetScratchBuffer<int>((batch_size + 1) * sizeof(int), ctx->GetComputeStream()); - ORT_RETURN_IF_ERROR(LaunchSeqStartInit( - hip_stream, seqstart_q_tmp.get(), batch_size, - query_strides.strides_for_bnsh_coord.x / query_strides.strides_for_bnsh_coord.z)); - ORT_RETURN_IF_ERROR(LaunchSeqStartInit( - hip_stream, seqstart_k_tmp.get(), batch_size, - present_strides.strides_for_bnsh_coord.x / present_strides.strides_for_bnsh_coord.z)); - - fmha_fwd_args args{ - query_ptr, - present_key->DataRaw(), - present_value->DataRaw(), - nullptr, // bias, alibi/element - nullptr, // lse, logsumexp buffer - output->MutableDataRaw(), - seqstart_q_tmp.get(), // seqstart_q_ptr, for group mode - seqstart_k_tmp.get(), // seqstart_k_ptr, for group mode - seqlens_k_ptr, // seqlen_k_ptr, for group mode - sequence_length, // seqlen_q, for batch mode - kv_sequence_length, // seqlen_k, for batch mode - parameters.batch_size, // batch - parameters.sequence_length, // max_seqlen_q - parameters.head_size, // hdim_q - parameters.head_size, // hdim_v - parameters.num_heads, - parameters.kv_num_heads, - scale, - 1.0f, // scale_p of squant, useless - 1.0f, // scale_o of squant, useless - static_cast<ck_tile::index_t>(query_strides.strides_for_bnsh_coord.z), // stride_q, to be regarded as stride of dim S - static_cast<ck_tile::index_t>(present_strides.strides_for_bnsh_coord.z), // stride_k, to be regarded as stride of dim S - static_cast<ck_tile::index_t>(present_strides.strides_for_bnsh_coord.z), // stride_v, to be regarded as stride of dim S - batch_size, // stride_bias, if alibi, b*h need set this to h, 1*h need set this to 0 - static_cast<ck_tile::index_t>(output_strides.strides_for_bnsh_coord.z), // stride_o, to be regarded as stride of dim S - static_cast<ck_tile::index_t>(query_strides.strides_for_bnsh_coord.y), // nhead_stride_q, to be regarded as stride of dim N - static_cast<ck_tile::index_t>(present_strides.strides_for_bnsh_coord.y), // nhead_stride_k, to be regarded as stride of dim N - static_cast<ck_tile::index_t>(present_strides.strides_for_bnsh_coord.y), // nhead_stride_v, to be regarded as stride of dim N - 0, // nhead_stride_bias - batch_size, // nhead_stride_lse - static_cast<ck_tile::index_t>(output_strides.strides_for_bnsh_coord.y), // nhead_stride_o, to be regarded as stride of dim N - static_cast<ck_tile::index_t>(query_strides.strides_for_bnsh_coord.x), // batch_stride_q, to be regarded as stride of dim B - static_cast<ck_tile::index_t>(present_strides.strides_for_bnsh_coord.x), // batch_stride_k, to be regarded as stride of dim B - static_cast<ck_tile::index_t>(present_strides.strides_for_bnsh_coord.x), // batch_stride_v, to be regarded as stride of dim B - 0, // batch_stride_bias - num_heads * batch_size, // batch_stride_lse - static_cast<ck_tile::index_t>(output_strides.strides_for_bnsh_coord.x), // batch_stride_o, to be regarded as stride of dim B - mask.left, // window_size_left - mask.right, // window_size_right - static_cast<ck_tile::index_t>(mask.type)}; - -#if 0 - std::cout - << "\n sequence_length:" << sequence_length - << "\n kv_sequence_length:" << kv_sequence_length - << "\n seqlen_past_kv_cache:" << parameters.seqlen_past_kv_cache - << "\n seqlen_present_kv_cache:" << parameters.seqlen_present_kv_cache << std::endl; - - std::cout - << "\n
q_ptr:" << args.q_ptr - << "\n k_ptr:" << args.k_ptr - << "\n v_ptr:" << args.v_ptr - << "\n bias_ptr:" << args.bias_ptr - << "\n lse_ptr:" << args.lse_ptr - << "\n o_ptr:" << args.o_ptr - << "\n seqstart_q_ptr:" << args.seqstart_q_ptr - << "\n seqstart_k_ptr:" << args.seqstart_k_ptr - << "\n seqlen_k_ptr:" << args.seqlen_k_ptr - << "\n seqlen_q:" << args.seqlen_q - << "\n seqlen_k:" << args.seqlen_k - << "\n batch:" << args.batch - << "\n max_seqlen_q:" << args.max_seqlen_q - << "\n hdim_q:" << args.hdim_q - << "\n hdim_v:" << args.hdim_v - << "\n nhead_q:" << args.nhead_q - << "\n nhead_k:" << args.nhead_k - << "\n scale_s:" << args.scale_s - << "\n scale_p:" << args.scale_p - << "\n scale_o:" << args.scale_o - << "\n stride_q:" << args.stride_q - << "\n stride_k:" << args.stride_k - << "\n stride_v:" << args.stride_v - << "\n stride_bias:" << args.stride_bias - << "\n stride_o:" << args.stride_o - << "\n nhead_stride_q:" << args.nhead_stride_q - << "\n nhead_stride_k:" << args.nhead_stride_k - << "\n nhead_stride_v:" << args.nhead_stride_v - << "\n nhead_stride_bias:" << args.nhead_stride_bias - << "\n nhead_stride_lse:" << args.nhead_stride_lse - << "\n nhead_stride_o:" << args.nhead_stride_o - << "\n batch_stride_q:" << args.batch_stride_q - << "\n batch_stride_k:" << args.batch_stride_k - << "\n batch_stride_v:" << args.batch_stride_v - << "\n batch_stride_bias:" << args.batch_stride_bias - << "\n batch_stride_lse:" << args.batch_stride_lse - << "\n batch_stride_o:" << args.batch_stride_o - << "\n window_size_left:" << args.window_size_left - << "\n window_size_right:" << args.window_size_right - << "\n mask_type:" << args.mask_type - << std::endl; -#endif - - fmha_fwd_traits traits{ - parameters.head_size, - parameters.head_size, // v head size - GetCkFmhaDataTypeString(), - !parameters.is_first_prompt, // true, // is_group_mode - true, // is_v_rowmajor ? dim is fastest : seq is fastest - mask.type, - bias_type, - false, // has_lse - false, // do_fp8_static_quant, aka, squant - }; - - ck_tile::stream_config stream_config{ - hip_stream, - false // time_kernel - }; - - auto duration = fmha_fwd(traits, args, stream_config); - if (duration < 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "fmha_fwd internal error"); - } - HIP_RETURN_IF_ERROR(hipGetLastError()); - - return Status::OK(); -#else - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "GroupQueryAttention requires ck_tile to be enabled"); -#endif -} - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.h b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.h deleted file mode 100644 index ce0de1f761aa5..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.h +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
diff --git a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.h b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.h deleted file mode 100644 index ce0de1f761aa5..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.h +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include -#include "core/providers/rocm/rocm_kernel.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -using namespace onnxruntime::rocm; - -template <typename T> -class GroupQueryAttention final : public RocmKernel { - public: - GroupQueryAttention(const OpKernelInfo& info); - Status ComputeInternal(OpKernelContext* context) const override; - - protected: - int num_heads_; // number of attention heads - int kv_num_heads_; // different for k and v for group query attention - int local_window_size_; - bool is_unidirectional_; - bool is_past_bsnh_; - bool do_rotary_; - bool rotary_interleaved_; - float scale_; - - private: - static std::once_flag arch_checking_; -}; - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime
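The layer-norm kernels in the next file reduce the pair (mean, mean of squares) in a single block-wide pass and derive the variance as E[x^2] - mu^2, which is why the reduction carries a hipcub::KeyValuePair. A host-side reference of the same arithmetic, useful as a mental model for checking kernel output (sketch only, not part of the deleted sources):

#include <cmath>
#include <vector>

// LayerNorm(x) = gamma * (x - mu) / sqrt(Var[x] + eps) + beta,
// with Var[x] computed as E[x^2] - mu^2 in a single accumulation pass.
void LayerNormRef(std::vector<float>& x, const float* gamma, const float* beta, float eps) {
  double sum = 0.0, sq = 0.0;
  for (float v : x) { sum += v; sq += double(v) * v; }
  const float mu = float(sum / x.size());
  const float rsigma = 1.0f / std::sqrt(float(sq / x.size()) - mu * mu + eps);
  for (size_t i = 0; i < x.size(); ++i) {
    const float b = beta ? beta[i] : 0.0f;
    x[i] = gamma[i] * (x[i] - mu) * rsigma + b;
  }
}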
diff --git a/onnxruntime/contrib_ops/rocm/bert/layer_norm.cuh b/onnxruntime/contrib_ops/rocm/bert/layer_norm.cuh deleted file mode 100644 index 2eeb7c3e8f279..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/layer_norm.cuh +++ /dev/null @@ -1,270 +0,0 @@ -#include "hip/hip_runtime.h" -/* - The implementation of this file is based on bert plugins in TensorRT demo: - https://github.com/NVIDIA/TensorRT/tree/release/5.1/demo/BERT/ - -Copyright 2019 NVIDIA Corporation - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. -#pragma once - -#include -#include -#include -#include "core/providers/rocm/rocm_common.h" -#include "core/providers/rocm/cu_inc/common.cuh" -#include "core/providers/rocm/shared_inc/rocm_call.h" - -using namespace onnxruntime::rocm; - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -template <typename T> -__device__ inline T Rsqrt(const T& x); - -template <> -__device__ inline float Rsqrt(const float& x) { - return rsqrtf(x); -} - -template <> -__device__ inline half Rsqrt(const half& x) { -#if __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__) - return hrsqrt(x); -#else - return half(rsqrtf(static_cast<float>(x))); -#endif -} - -__device__ inline half2 AddHalf2(const half2 a, const half2 b) { -#if __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__) - return __hadd2(a, b); -#else - return __halves2half2(__hadd(a.x, b.x), __hadd(a.y, b.y)); -#endif -} - -struct KeyValuePairSum { - __device__ inline hipcub::KeyValuePair<float, float> operator()(const hipcub::KeyValuePair<float, float>& a, - const hipcub::KeyValuePair<float, float>& b) { - return hipcub::KeyValuePair<float, float>(a.key + b.key, a.value + b.value); - } - - __device__ inline hipcub::KeyValuePair<half, half> operator()(const hipcub::KeyValuePair<half, half>& a, - const hipcub::KeyValuePair<half, half>& b) { - const half2 a2 = __halves2half2(a.key, a.value); - const half2 b2 = __halves2half2(b.key, b.value); - const half2 res = AddHalf2(a2, b2); - return hipcub::KeyValuePair<half, half>(__low2half(res), __high2half(res)); - } - - __device__ inline hipcub::KeyValuePair<half2, half2> operator()(const hipcub::KeyValuePair<half2, half2>& a, - const hipcub::KeyValuePair<half2, half2>& b) { - return hipcub::KeyValuePair<half2, half2>(AddHalf2(a.key, b.key), AddHalf2(a.value, b.value)); - } -}; - -template <typename U, typename V, int TPB> -__device__ inline void LayerNorm( - const hipcub::KeyValuePair<U, U>& thread_data, const int ld, const int offset, const V* beta, - const V* gamma, const U epsilon, V* output) { - // Assuming thread_data is already divided by ld - - using BlockReduce = hipcub::BlockReduce<hipcub::KeyValuePair<U, U>, TPB>; - __shared__ typename BlockReduce::TempStorage temp_storage; - __shared__ U mu; // mean - __shared__ U rsigma; // 1 / std.dev. - - KeyValuePairSum pair_sum; - const auto sum_kv = BlockReduce(temp_storage).Reduce(thread_data, pair_sum); - - if (threadIdx.x == 0) { - mu = sum_kv.key; - rsigma = Rsqrt(sum_kv.value - mu * mu + epsilon); - } - __syncthreads(); - - for (int i = threadIdx.x; i < ld; i += TPB) { - const int idx = offset + i; - const U val = static_cast<U>(output[idx]); - const U g = static_cast<U>(gamma[i]); - const U b = (nullptr == beta) ? U(0.f) : static_cast<U>(beta[i]); - output[idx] = static_cast<V>(g * (val - mu) * rsigma + b); - } -} - -template <typename U, typename V, int TPB> -__device__ inline void SimplifiedLayerNorm( - const U& thread_data, const int ld, const int offset, const V* gamma, const U epsilon, V* output) { - // Assuming thread_data is already divided by ld - - using BlockReduce = hipcub::BlockReduce<U, TPB>; - __shared__ typename BlockReduce::TempStorage temp_storage; - __shared__ U rsigma; // 1 / std.dev.
- - const U sum = BlockReduce(temp_storage).Sum(thread_data); - - if (threadIdx.x == 0) { - rsigma = Rsqrt(sum + epsilon); - } - __syncthreads(); - - for (int i = threadIdx.x; i < ld; i += TPB) { - const int idx = offset + i; - const U val = static_cast<U>(output[idx]); - const U g = static_cast<U>(gamma[i]); - output[idx] = static_cast<V>(g * val * rsigma); - } -} - -template <typename U, typename V, int TPB, int ILP> -__device__ inline void SimplifiedLayerNormVec( - const U& thread_data, const int ld, const int offset, const V* gamma, const U epsilon, V* output) { - // Assuming thread_data is already divided by ld - using VecV = aligned_vector<V, ILP>; - using BlockReduce = hipcub::BlockReduce<U, TPB>; - __shared__ typename BlockReduce::TempStorage temp_storage; - __shared__ U rsigma; // 1 / std.dev. - - const U sum = BlockReduce(temp_storage).Sum(thread_data); - - if (threadIdx.x == 0) { - rsigma = Rsqrt(sum + epsilon); - } - __syncthreads(); - - if (ILP * threadIdx.x < ld) { - for (int i = threadIdx.x * ILP; i < ld; i += TPB * ILP) { - int idx = offset + i; - const VecV gamma_v = *reinterpret_cast<const VecV*>(gamma + i); - VecV output_v = *reinterpret_cast<const VecV*>(output + idx); - -#pragma unroll - for (int k = 0; k < ILP; k++) { - output_v.val[k] = U(gamma_v.val[k]) * U(output_v.val[k]) * rsigma; - } - *(reinterpret_cast<VecV*>(output + idx)) = output_v; - } - } -} - -template <typename U, typename V, int TPB, int ILP> -__device__ inline void LayerNormVec( - const hipcub::KeyValuePair<U, U>& thread_data, const int ld, const int offset, const V* beta, - const V* gamma, const U epsilon, V* output) { - // Assuming thread_data is already divided by ld - using VecV = aligned_vector<V, ILP>; - using BlockReduce = hipcub::BlockReduce<hipcub::KeyValuePair<U, U>, TPB>; - __shared__ typename BlockReduce::TempStorage temp_storage; - __shared__ U mu; // mean - __shared__ U rsigma; // 1 / std.dev. - - KeyValuePairSum pair_sum; - const auto sum_kv = BlockReduce(temp_storage).Reduce(thread_data, pair_sum); - - if (threadIdx.x == 0) { - mu = sum_kv.key; - rsigma = Rsqrt(sum_kv.value - mu * mu + epsilon); - } - __syncthreads(); - - if (ILP * threadIdx.x < ld) { - for (int i = threadIdx.x * ILP; i < ld; i += TPB * ILP) { - int idx = offset + i; - const VecV beta_v = (beta != nullptr) ? *reinterpret_cast<const VecV*>(beta + i) : VecV(); - const VecV gamma_v = *reinterpret_cast<const VecV*>(gamma + i); - VecV output_v = *reinterpret_cast<const VecV*>(output + idx); - -#pragma unroll - for (int k = 0; k < ILP; k++) { - output_v.val[k] = (beta != nullptr) ? U(gamma_v.val[k]) * (U(output_v.val[k]) - mu) * rsigma + U(beta_v.val[k]) : U(gamma_v.val[k]) * (U(output_v.val[k]) - mu) * rsigma; - } - *(reinterpret_cast<VecV*>(output + idx)) = output_v; - } - } -} - -template <typename T, typename U, typename V, int TPB, int ILP> -__device__ inline void LayerNormSmall(const T* input_v, const hipcub::KeyValuePair<U, U>& thread_data, - const int ld, const int idx, const V* beta, const V* gamma, - const U epsilon, V* output) { - // Assuming thread_data is already divided by ld - // Small settings: the block covers the leading dimension TPB >= ld. The input - // value is available in a register - using VecV = aligned_vector<V, ILP>; - using BlockReduce = hipcub::BlockReduce<hipcub::KeyValuePair<U, U>, TPB>; - __shared__ typename BlockReduce::TempStorage temp_storage; - __shared__ U mu; // mean - __shared__ U rsigma; // 1 / std.dev. - - KeyValuePairSum pair_sum; - const hipcub::KeyValuePair<U, U> sum_kv = BlockReduce(temp_storage).Reduce(thread_data, pair_sum); - - if (threadIdx.x == 0) { - mu = sum_kv.key; - rsigma = Rsqrt(sum_kv.value - mu * mu + epsilon); - } - __syncthreads(); - - if (ILP * threadIdx.x < ld) { - const VecV beta_v = (beta != nullptr) ?
*reinterpret_cast<const VecV*>(beta + threadIdx.x * ILP) : VecV(); - const VecV gamma_v = *reinterpret_cast<const VecV*>(gamma + threadIdx.x * ILP); - VecV output_v; - -#pragma unroll - for (int i = 0; i < ILP; i++) { - output_v.val[i] = (beta != nullptr) ? U(gamma_v.val[i]) * (U(input_v[i]) - mu) * rsigma + U(beta_v.val[i]) : U(gamma_v.val[i]) * (U(input_v[i]) - mu) * rsigma; - } - *(reinterpret_cast<VecV*>(output + idx)) = output_v; - } -} - -template <typename T, typename U, typename V, int TPB, int ILP> -__device__ inline void SimplifiedLayerNormSmall(const T* input_v, const U& thread_data, const int ld, const int idx, - const V* gamma, const U epsilon, V* output) { - // Assuming thread_data is already divided by ld - // Small settings: the block covers the leading dimension TPB >= ld. The input - // value is available in a register - using VecV = aligned_vector<V, ILP>; - using BlockReduce = hipcub::BlockReduce<U, TPB>; - __shared__ typename BlockReduce::TempStorage temp_storage; - __shared__ U rsigma; // 1 / std.dev. - - const U sum = BlockReduce(temp_storage).Sum(thread_data); - - if (threadIdx.x == 0) { - rsigma = Rsqrt(sum + epsilon); - } - __syncthreads(); - - if (ILP * threadIdx.x < ld) { - const VecV gamma_v = *reinterpret_cast<const VecV*>(gamma + threadIdx.x * ILP); - VecV output_v; - -#pragma unroll - for (int i = 0; i < ILP; i++) { - output_v.val[i] = U(gamma_v.val[i]) * U(input_v[i]) * rsigma; - } - *(reinterpret_cast<VecV*>(output + idx)) = output_v; - } -} - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime
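The Simplified variants above drop the mean subtraction and the beta term, i.e. they are RMSNorm-style: out = x * rsqrt(mean(x^2) + eps) * gamma. A host-side sketch of the same computation:

#include <cmath>
#include <vector>

// RMSNorm-style normalization: no mean subtraction, no beta.
void SimplifiedLayerNormRef(std::vector<float>& x, const float* gamma, float eps) {
  double sq = 0.0;
  for (float v : x) sq += double(v) * v;
  const float rsigma = 1.0f / std::sqrt(float(sq / x.size()) + eps);
  for (size_t i = 0; i < x.size(); ++i) x[i] = gamma[i] * x[i] * rsigma;
}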
- -#include "contrib_ops/rocm/bert/multihead_attention.h" - -#include "contrib_ops/cpu/bert/multihead_attention_helper.h" -#include "contrib_ops/rocm/bert/attention_impl.h" -#include "contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh" -#include "core/platform/env_var_utils.h" -#include "core/providers/rocm/rocm_common.h" - -using namespace onnxruntime::rocm; -using namespace ::onnxruntime::common; -using namespace ONNX_NAMESPACE; - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -#define REGISTER_MHA_KERNEL_TYPED(T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - MultiHeadAttention, \ - kMSDomain, \ - 1, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - MultiHeadAttention) - -REGISTER_MHA_KERNEL_TYPED(float); -REGISTER_MHA_KERNEL_TYPED(MLFloat16); - -static constexpr int kPastSequenceLengthInputIndex = 7; -static constexpr int kBeamWidthInputIndex = 8; -static constexpr int kPastInputIndex = 5; -static constexpr int kPresentOutputIndex = 1; - -#define REGISTER_DMMHA_KERNEL_TYPED(T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - DecoderMaskedMultiHeadAttention, \ - kMSDomain, \ - 1, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .MayInplace(kPastInputIndex, kPresentOutputIndex) \ - .MayInplace(kPastInputIndex + 1, kPresentOutputIndex + 1) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ - .InputMemoryType(OrtMemTypeCPUInput, kPastSequenceLengthInputIndex) \ - .InputMemoryType(OrtMemTypeCPUInput, kBeamWidthInputIndex), \ - MultiHeadAttention) - -REGISTER_DMMHA_KERNEL_TYPED(float); -REGISTER_DMMHA_KERNEL_TYPED(MLFloat16); - -template -MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info) - : RocmKernel(info), - attn_type_(info.node().OpType() == "DecoderMaskedMultiHeadAttention" ? 
kDecoderMaskedMultiHeadAttention - : kMultiHeadAttention) { - int64_t num_heads = 0; - ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0); - num_heads_ = static_cast<int>(num_heads); - - mask_filter_value_ = info.GetAttrOrDefault<float>("mask_filter_value", -10000.0f); - - scale_ = info.GetAttrOrDefault<float>("scale", 0.0f); - - past_present_share_buffer_ = info.GetAttrOrDefault<int64_t>("past_present_share_buffer", 0LL) != 0LL; - is_unidirectional_ = info.GetAttrOrDefault<int64_t>("unidirectional", 0) == 1; - - using HipT = typename ToHipType<T>::MappedType; - using AttentionTunableOp = GemmSoftmaxGemmPermuteTunableOp<HipT>; - tunable_op_ = std::make_shared<AttentionTunableOp>(); -} - -template <typename T> -Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const { - ORT_ENFORCE( - GetTuningContext()->IsTunableOpEnabled(), - "MultiHeadAttention of ROCm EP is only supported if tunable op is used and tuning is enabled."); - - const Tensor* query = context->Input<Tensor>(0); - const Tensor* key = context->Input<Tensor>(1); - const Tensor* value = context->Input<Tensor>(2); - - const Tensor* bias{}; - const Tensor* key_padding_mask{}; - const Tensor* attention_bias{}; - const Tensor* past_key{}; - const Tensor* past_value{}; - const Tensor* past_seq_len{}; - - const Tensor* cache_indirection = nullptr; - - if (attn_type_ == kMultiHeadAttention) { - bias = context->Input<Tensor>(3); - key_padding_mask = context->Input<Tensor>(4); - attention_bias = context->Input<Tensor>(5); - past_key = context->Input<Tensor>(6); - past_value = context->Input<Tensor>(7); - } else if (attn_type_ == kDecoderMaskedMultiHeadAttention) { - key_padding_mask = context->Input<Tensor>(3); - attention_bias = context->Input<Tensor>(4); - past_key = context->Input<Tensor>(5); - past_value = context->Input<Tensor>(6); - past_seq_len = context->Input<Tensor>(kPastSequenceLengthInputIndex); - // const Tensor* beam_width = context->Input<Tensor>(8); // NOTE: not used - // const Tensor* cache_indirection = context->Input<Tensor>(9); // TODO: should not present for ROCm EP - bias = context->Input<Tensor>(10); - } - - if (nullptr != bias) { - return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, - "qkv_bias is not supported on ROCm EP. " - "User should fuse the qkv bias to qkv projection instead."); - } - - auto& device_prop = GetDeviceProp(); - RocmAttentionParameters attn; - ORT_RETURN_IF_ERROR(multihead_attention_helper::CheckInputs<Tensor>(query, - key, - value, - bias, - key_padding_mask, - attention_bias, - past_key, - past_value, - cache_indirection, - past_seq_len, - &attn, /* parameters */ - num_heads_, - mask_filter_value_, - scale_, - is_unidirectional_, - past_present_share_buffer_, - attn_type_, - device_prop.maxThreadsPerBlock)); - - if (attn_type_ == kDecoderMaskedMultiHeadAttention && attn.sequence_length != 1) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input sequence length should be 1 to use DecoderMaskedMultiHeadAttention"); - } - - TensorShapeVector output_shape(3); - output_shape[0] = static_cast<int64_t>(attn.batch_size); - output_shape[1] = static_cast<int64_t>(attn.sequence_length); - output_shape[2] = static_cast<int64_t>(attn.v_hidden_size); - Tensor* output = context->Output(0, output_shape); - - std::vector<int64_t> present_dims{ - attn.batch_size, - attn.num_heads, - past_present_share_buffer_ ?
attn.max_sequence_length : attn.total_sequence_length, - attn.head_size, - }; - TensorShape present_shape(present_dims); - Tensor* present_key = context->Output(1, present_shape); - Tensor* present_value = context->Output(2, present_shape); - - ORT_RETURN_IF_ERROR(ClassifyAttentionMode( - attn_type_, &attn, - /*qkv=*/{query, key, value}, - /*past=*/{past_key, past_value}, - /*present=*/{present_key, present_value})); - - using HipT = typename ToHipType<T>::MappedType; - using AttentionTunableOp = GemmSoftmaxGemmPermuteTunableOp<HipT>; - auto workspace_bytes = AttentionTunableOp::GetWorkspaceNumBytes(&attn); - auto workspace = GetScratchBuffer<void>(workspace_bytes, context->GetComputeStream()); - - hipStream_t stream = Stream(context); - if (nullptr != present_key) { // process past present concat - Strides dst_strides; - - int4 past_shape; - Strides past_src_strides; - const HipT* past_key_src; - const HipT* past_value_src; - HipT* past_key_dst{}; - HipT* past_value_dst{}; - - int4 add_shape; - Strides add_src_strides; - const HipT* add_key_src = reinterpret_cast<const HipT*>(key->DataRaw()); - const HipT* add_value_src = reinterpret_cast<const HipT*>(value->DataRaw()); - HipT* add_key_dst; - HipT* add_value_dst; - - if (attn.mode == BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH || - attn.mode == BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH) { - dst_strides = Strides::BNSHMemory(attn.batch_size, attn.num_heads, attn.total_sequence_length, attn.head_size); - - past_shape = {attn.batch_size, attn.num_heads, attn.past_sequence_length, attn.head_size}; - past_src_strides = Strides::BNSHMemory(attn.batch_size, attn.num_heads, attn.past_sequence_length, attn.head_size); - past_key_src = reinterpret_cast<const HipT*>(past_key->DataRaw()); - past_value_src = reinterpret_cast<const HipT*>(past_value->DataRaw()); - past_key_dst = reinterpret_cast<HipT*>(present_key->MutableDataRaw()); - past_value_dst = reinterpret_cast<HipT*>(present_value->MutableDataRaw()); - - if (attn.mode == BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH) { - add_src_strides = Strides::BSNHMemory(attn.batch_size, attn.kv_sequence_length, attn.num_heads, attn.head_size); - } else if (attn.mode == BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH) { - add_src_strides = Strides::BNSHMemory(attn.batch_size, attn.num_heads, attn.kv_sequence_length, attn.head_size); - } - } else if (attn.mode == BSNH_BLNH_BLNH_NONE_NONE_BNTH_BNTH || - attn.mode == BSNH_BNLH_BNLH_NONE_NONE_BNTH_BNTH) { - dst_strides = Strides::BNSHMemory(attn.batch_size, attn.num_heads, attn.total_sequence_length, attn.head_size); - - if (attn.mode == BSNH_BLNH_BLNH_NONE_NONE_BNTH_BNTH) { - add_src_strides = Strides::BSNHMemory(attn.batch_size, attn.kv_sequence_length, attn.num_heads, attn.head_size); - } else if (attn.mode == BSNH_BNLH_BNLH_NONE_NONE_BNTH_BNTH) { - add_src_strides = Strides::BNSHMemory(attn.batch_size, attn.num_heads, attn.kv_sequence_length, attn.head_size); - } - } else if ( - attn.mode == BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH || - attn.mode == BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH || - attn.mode == BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH || - attn.mode == BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH) { - dst_strides = Strides::BNSHMemory(attn.batch_size, attn.num_heads, attn.max_sequence_length, attn.head_size); - - if (attn.mode == BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH || attn.mode == BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH) { - add_src_strides = Strides::BSNHMemory(attn.batch_size, attn.kv_sequence_length, attn.num_heads, attn.head_size); - } else if (attn.mode == BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH || attn.mode == BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH) { - add_src_strides =
Strides::BNSHMemory(attn.batch_size, attn.num_heads, attn.kv_sequence_length, attn.head_size); - } - } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, - "past present concatenation is not implemented for attention mode ", attn.mode); - } - add_shape = {attn.batch_size, attn.num_heads, attn.kv_sequence_length, attn.head_size}; // kernel in coord (b,n,s,h) - add_key_dst = reinterpret_cast<HipT*>(present_key->MutableDataRaw()) + dst_strides.OffsetAt(0, 0, attn.past_sequence_length, 0); - add_value_dst = reinterpret_cast<HipT*>(present_value->MutableDataRaw()) + dst_strides.OffsetAt(0, 0, attn.past_sequence_length, 0); - - if (past_key_dst) { - ORT_RETURN_IF_ERROR(LaunchStridedCopy( - stream, past_key_src, past_shape, past_src_strides.ForBNSHCoord(), - past_key_dst, dst_strides.ForBNSHCoord(), device_prop.maxThreadsPerBlock)); - } - if (past_value_dst) { - ORT_RETURN_IF_ERROR(LaunchStridedCopy( - stream, past_value_src, past_shape, past_src_strides.ForBNSHCoord(), - past_value_dst, dst_strides.ForBNSHCoord(), device_prop.maxThreadsPerBlock)); - } - - ORT_RETURN_IF_ERROR(LaunchStridedCopy( - stream, add_key_src, add_shape, add_src_strides.ForBNSHCoord(), - add_key_dst, dst_strides.ForBNSHCoord(), device_prop.maxThreadsPerBlock)); - ORT_RETURN_IF_ERROR(LaunchStridedCopy( - stream, add_value_src, add_shape, add_src_strides.ForBNSHCoord(), - add_value_dst, dst_strides.ForBNSHCoord(), device_prop.maxThreadsPerBlock)); - } - - GemmSoftmaxGemmPermuteParams<HipT> params; - params.tuning_ctx = GetTuningContext(); - params.stream = context->GetComputeStream(); - params.handle = GetHipblasHandle(context); - params.attention = &attn; - params.device_prop = &device_prop; - params.scale = scale_ == 0 ? 1.0f / sqrt(attn.head_size) : scale_; - std::tie(params.q_buffer, params.k_buffer, params.v_buffer) = ConvertToOffsetedBufferViews<HipT>( - &attn, - nullptr == query ? nullptr : reinterpret_cast<const HipT*>(query->DataRaw()), - nullptr == key ? nullptr : reinterpret_cast<const HipT*>(key->DataRaw()), - nullptr == value ? nullptr : reinterpret_cast<const HipT*>(value->DataRaw()), - nullptr == present_key ? nullptr : reinterpret_cast<const HipT*>(present_key->DataRaw()), - nullptr == present_value ? nullptr : reinterpret_cast<const HipT*>(present_value->DataRaw())); - params.out_buffer = reinterpret_cast<HipT*>(output->MutableDataRaw()); - - if (key_padding_mask != nullptr) { - params.mask_index_buffer = key_padding_mask->Data<int>(); - params.mask_index_dims = key_padding_mask->Shape().AsShapeVector(); - } - - if (attention_bias != nullptr) { - params.bias_buffer = reinterpret_cast<const HipT*>(attention_bias->DataRaw()); - } - - params.workspace_buffer = reinterpret_cast<HipT*>(workspace.get()); - return (*std::static_pointer_cast<AttentionTunableOp>(tunable_op_))(&params); -} - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime
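The past/present handling above copies the past K/V into offset 0 of the present buffer and appends the current step's K/V at the past_sequence_length offset. The same bookkeeping on a contiguous [B, N, T, H] buffer, as a simplified sketch (the deleted code additionally handles strided BSNH/BNSH layouts and shared buffers):

#include <cstring>

// T_total = past_len + kv_len; past may be null when there is no cache yet.
void AppendKv(float* present, const float* past, const float* added,
              int B, int N, int past_len, int kv_len, int H) {
  const int total = past_len + kv_len;
  for (int b = 0; b < B; ++b) {
    for (int n = 0; n < N; ++n) {
      float* dst = present + static_cast<size_t>((b * N + n)) * total * H;
      if (past)  // copy the cached prefix first
        std::memcpy(dst, past + static_cast<size_t>((b * N + n)) * past_len * H,
                    sizeof(float) * past_len * H);
      // then append this step's K or V right after it
      std::memcpy(dst + static_cast<size_t>(past_len) * H,
                  added + static_cast<size_t>((b * N + n)) * kv_len * H,
                  sizeof(float) * kv_len * H);
    }
  }
}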
diff --git a/onnxruntime/contrib_ops/rocm/bert/multihead_attention.h b/onnxruntime/contrib_ops/rocm/bert/multihead_attention.h deleted file mode 100644 index 1d676d7a7bcac..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/multihead_attention.h +++ /dev/null @@ -1,51 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include "core/providers/rocm/rocm_kernel.h" -#include "contrib_ops/rocm/bert/attention_impl.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -using namespace onnxruntime::rocm; - -template <typename T> -class MultiHeadAttention final : public RocmKernel { - public: - MultiHeadAttention(const OpKernelInfo& info); - Status ComputeInternal(OpKernelContext* context) const override; - - protected: - AttentionType attn_type_; - int num_heads_; // number of attention heads - float mask_filter_value_; - float scale_; - bool past_present_share_buffer_{false}; - bool is_unidirectional_{false}; - - // type-erased GemmSoftmaxGemmPermuteTunableOp, the reasons for this are: - // 1. We don't want to include the cuh file where GemmSoftmaxGemmPermuteTunableOp is defined. - // 2. We don't want to construct the object repeatedly (which is expensive) during Compute. - std::shared_ptr<void> tunable_op_; -}; - -template <typename T> -class DecoderMaskedMultiHeadAttention final : public RocmKernel { - public: - DecoderMaskedMultiHeadAttention(const OpKernelInfo& info); - Status ComputeInternal(OpKernelContext* context) const override; - - protected: - AttentionType mha_type; - int num_heads_; // number of attention heads - float mask_filter_value_; - float scale_; -}; - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.cc b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.cc deleted file mode 100644 index 9e649fb591896..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.cc +++ /dev/null @@ -1,132 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "contrib_ops/rocm/bert/skip_layer_norm.h" - -#include "core/providers/rocm/rocm_common.h" -#include "contrib_ops/rocm/bert/skip_layer_norm_impl.h" -#include "contrib_ops/rocm/bert/transformer_common.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -#define REGISTER_KERNEL_TYPED(T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - SkipLayerNormalization, \ - kMSDomain, \ - 1, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \ - SkipLayerNorm<T, false>); \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - SkipSimplifiedLayerNormalization, \ - kMSDomain, \ - 1, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \ - SkipLayerNorm<T, true>); - -REGISTER_KERNEL_TYPED(float) -REGISTER_KERNEL_TYPED(MLFloat16) - -using namespace ONNX_NAMESPACE; - -template <typename T, bool Simplified> -SkipLayerNorm<T, Simplified>::SkipLayerNorm(const OpKernelInfo& op_kernel_info) : RocmKernel(op_kernel_info) { - ORT_ENFORCE(op_kernel_info.GetAttr("epsilon", &epsilon_).IsOK()); - ORT_ENFORCE(epsilon_ >= 0); -} - -template <typename T, bool Simplified> -Status SkipLayerNorm<T, Simplified>::ComputeInternal(OpKernelContext* ctx) const { - const Tensor* input = ctx->Input<Tensor>(0); - const Tensor* skip = ctx->Input<Tensor>(1); - const Tensor* gamma = ctx->Input<Tensor>(2); - - const Tensor* beta = Simplified ? nullptr : ctx->Input<Tensor>(3); - const Tensor* bias = Simplified ?
ctx->Input(3) : ctx->Input(4); - - Tensor* output = ctx->Output(0, input->Shape()); - - // For inferencing, we support one more optional output which is the sum - // of the input and skip tensors - Tensor* skip_input_bias_add_output = ctx->Output(3, input->Shape()); - - if (input->Shape() != skip->Shape()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "skip is expected to have same shape as input"); - } - - if (input->Shape().Size() == 0) { - return Status::OK(); - } - - const auto& input_dims = input->Shape().GetDims(); - size_t input_dims_size = input_dims.size(); - if (input_dims_size != 3 && input_dims_size != 2) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "input is expected to have 3 or 2 dimensions, got ", input_dims_size); - } - - int hidden_size = static_cast(input_dims[input_dims_size - 1]); - - const auto& gamma_dims = gamma->Shape().GetDims(); - if (gamma_dims.size() != 1) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "gamma is expected to have 1 dimension, got ", gamma_dims.size()); - } - if (gamma_dims[0] != hidden_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Last dimension of gamma and input does not match"); - } - - if (nullptr != beta) { - const auto& beta_dims = beta->Shape().GetDims(); - if (beta_dims.size() != 1) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "beta is expected to have 1 dimension, got ", beta_dims.size()); - } - if (beta_dims[0] != hidden_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Last dimension of beta and input does not match"); - } - } - - if (nullptr != bias) { - const auto& bias_dims = bias->Shape().GetDims(); - if (bias_dims.size() != 1) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "bias is expected to have 1 dimension, got ", bias_dims.size()); - } - if (bias_dims[0] != hidden_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Last dimension of bias and input does not match"); - } - } - - int64_t element_count = input->Shape().Size(); - typedef typename ToHipType::MappedType HipT; - - return LaunchSkipLayerNormKernel( - GetTuningContext(), - ctx->GetComputeStream(), - reinterpret_cast(output->MutableData()), - skip_input_bias_add_output != nullptr ? reinterpret_cast(skip_input_bias_add_output->MutableData()) : nullptr, - reinterpret_cast(input->Data()), - reinterpret_cast(skip->Data()), - reinterpret_cast(gamma->Data()), - (beta != nullptr) ? reinterpret_cast(beta->Data()) : nullptr, - (bias != nullptr) ? reinterpret_cast(bias->Data()) : nullptr, - epsilon_, - hidden_size, - static_cast(element_count)); -} - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.h b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.h deleted file mode 100644 index 02228bc59cedc..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.h +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
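For reference, the fused computation the deleted SkipLayerNorm kernel performs, under the shapes validated above, is LayerNorm(input + skip + bias) with an optional copy of the pre-normalization sum written to output index 3. A scalar sketch over one row of `hidden_size` elements (names illustrative, not the deleted API):

    #include <cmath>
    #include <vector>

    // Scalar reference for one SkipLayerNormalization row: accumulate
    // sum = input + skip (+ bias), then apply LayerNorm(sum) with gamma/beta.
    void SkipLayerNormRow(const float* input, const float* skip, const float* bias,
                          const float* gamma, const float* beta, float epsilon,
                          int hidden_size, float* output, float* sum_output) {
      std::vector<float> sum(hidden_size);
      float mean = 0.0f;
      for (int i = 0; i < hidden_size; ++i) {
        sum[i] = input[i] + skip[i] + (bias ? bias[i] : 0.0f);
        if (sum_output) sum_output[i] = sum[i];  // optional output 3
        mean += sum[i];
      }
      mean /= hidden_size;
      float var = 0.0f;
      for (int i = 0; i < hidden_size; ++i) var += (sum[i] - mean) * (sum[i] - mean);
      var /= hidden_size;
      const float rstd = 1.0f / std::sqrt(var + epsilon);
      for (int i = 0; i < hidden_size; ++i)
        output[i] = (sum[i] - mean) * rstd * gamma[i] + (beta ? beta[i] : 0.0f);
    }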
- -#pragma once -#include "core/common/common.h" -#include "core/providers/rocm/rocm_kernel.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -using namespace onnxruntime::rocm; - -template -class SkipLayerNorm final : public RocmKernel { - public: - SkipLayerNorm(const OpKernelInfo& op_kernel_info); - Status ComputeInternal(OpKernelContext* context) const override; - - private: - float epsilon_; -}; - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.cu b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.cu deleted file mode 100644 index 8387c49a3310b..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.cu +++ /dev/null @@ -1,86 +0,0 @@ -#include "hip/hip_runtime.h" -/* - The implementation of this file is based on skipLayerNorm plugin in TensorRT demo: - https://github.com/NVIDIA/TensorRT/tree/release/5.1/demo/BERT/ - -Copyright 2019 NVIDIA Corporation - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -// Modifications: Add SkipLayerNormKernelVec to -// leverage vectorized load/write. -// and templatize ComputeSkipLayerNorm for different -// data types. -// Copyright (c) Advanced Micro Devices, Inc. All rights reserved. -// Licensed under the MIT License. 
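The `Simplified` template flag in the deleted header selects the variant registered as SkipSimplifiedLayerNormalization, which (per the kernel code further down, where only the second moment and gamma are used) normalizes by the root mean square: no mean subtraction and no beta. A scalar sketch of that reduced formula, assuming it matches the usual RMSNorm definition:

    #include <cmath>

    // Simplified (RMSNorm-style) variant:
    // y[i] = sum[i] * gamma[i] / sqrt(mean(sum^2) + epsilon)
    void SkipSimplifiedLayerNormRow(const float* sum, const float* gamma,
                                    float epsilon, int hidden_size, float* output) {
      float sq = 0.0f;
      for (int i = 0; i < hidden_size; ++i) sq += sum[i] * sum[i];
      const float inv_rms = 1.0f / std::sqrt(sq / hidden_size + epsilon);
      for (int i = 0; i < hidden_size; ++i) output[i] = sum[i] * inv_rms * gamma[i];
    }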
- -#include "contrib_ops/rocm/bert/skip_layer_norm_impl.h" - -#include - -#include "contrib_ops/rocm/bert/skip_layer_norm_impl_kernel.h" -#include "contrib_ops/rocm/bert/skip_layer_norm_tunable_op.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -template -Status LaunchSkipLayerNormKernel( - RocmTuningContext* tuning_ctx, Stream* stream, V* output, T* skip_input_bias_add_output, const T* input, - const T* skip, const V* gamma, const V* beta, const T* bias, float epsilon, int ld, int element_count) { - // this must be true because element_count is the total size of the tensor - assert(element_count % ld == 0); - - SkipLayerNormParams params(tuning_ctx, stream, output, skip_input_bias_add_output, input, skip, - gamma, beta, bias, epsilon, ld, element_count); - - if (tuning_ctx->IsTunableOpEnabled()) { - static SkipLayerNormTunableOp op; - return op(¶ms); - } - - return SkipLayerNormStaticSelection(¶ms); -} - -template Status LaunchSkipLayerNormKernel( - RocmTuningContext* tuning_ctx, Stream* stream, float* output, float* skip_input_bias_add_output, const float* input, - const float* skip, const float* gamma, const float* beta, - const float* bias, float epsilon, int ld, - int element_count); - -template Status LaunchSkipLayerNormKernel( - RocmTuningContext* tuning_ctx, Stream* stream, half* output, half* skip_input_bias_add_output, const half* input, - const half* skip, const half* gamma, const half* beta, - const half* bias, float epsilon, int ld, - int element_count); - -template Status LaunchSkipLayerNormKernel( - RocmTuningContext* tuning_ctx, Stream* stream, float* output, float* skip_input_bias_add_output, const float* input, - const float* skip, const float* gamma, const float* beta, - const float* bias, float epsilon, int ld, - int element_count); - -template Status LaunchSkipLayerNormKernel( - RocmTuningContext* tuning_ctx, Stream* stream, half* output, half* skip_input_bias_add_output, const half* input, - const half* skip, const half* gamma, const half* beta, - const half* bias, float epsilon, int ld, - int element_count); - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.h b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.h deleted file mode 100644 index 5e2a92447d2f5..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.h +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#pragma once - -#include "core/common/common.h" -#include "core/providers/rocm/tunable/rocm_tunable.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -template -Status LaunchSkipLayerNormKernel( - RocmTuningContext* tuning, - Stream* stream, - V* output, // output tensor - T* skip_input_bias_add_output, // optional output tensor - const T* input, // input tensor - const T* skip, // skip tensor - const V* gamma, // Layer normalization gamma tensor - const V* beta, // Layer normalization beta tensor - const T* bias, // Layer normalization beta tensor - float epsilon, // Layer normalization epsilon - int hidden_size, // hidden size, it is the leading dimension (ld) - int element_count // number of elements in input tensor -); - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl_kernel.h b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl_kernel.h deleted file mode 100644 index fcfbc8969e498..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl_kernel.h +++ /dev/null @@ -1,162 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include "contrib_ops/rocm/bert/layer_norm.cuh" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -template -T maybe2half(float x); - -template <> -float maybe2half(float x) { - return x; -} - -template <> -half maybe2half(float x) { - return __float2half_rn(x); -} - -template -__global__ void SkipLayerNormKernel( - const int ld, const T* input, const T* skip, const V* beta, const V* gamma, const T* bias, - const U epsilon, V* output, T* skip_input_bias_add_output) { - const U reverse_ld = U(1.f / ld); - const int offset = blockIdx.x * ld; - - KeyValuePairSum pair_sum; - // reduce x and x^2 - hipcub::KeyValuePair thread_data(U(0.f), U(0.f)); - - for (int i = threadIdx.x; i < ld; i += TPB) { - const int idx = offset + i; - const U val = (bias == nullptr) ? static_cast(input[idx]) + static_cast(skip[idx]) : static_cast(input[idx]) + static_cast(skip[idx]) + static_cast(bias[i]); - const U rldval = reverse_ld * val; - thread_data = pair_sum(thread_data, hipcub::KeyValuePair(rldval, rldval * val)); - - if (skip_input_bias_add_output != nullptr) { - skip_input_bias_add_output[idx] = static_cast(val); - } - - output[idx] = static_cast(val); - } - - if constexpr (Simplified) { - SimplifiedLayerNorm(thread_data.value, ld, offset, gamma, epsilon, output); - return; - } - - LayerNorm(thread_data, ld, offset, beta, gamma, epsilon, output); -} - -// Vectorized kernel -template -__global__ void SkipLayerNormKernelVec( - const int ld, const T* input, const T* skip, const V* beta, const V* gamma, - const T* bias, const U epsilon, V* output, T* skip_input_bias_add_output, - bool hasBias, bool hasSkipInputBiasAdditionOutput) { - const U reverse_ld = U(1.f / ld); - const int offset = blockIdx.x * ld; - - KeyValuePairSum pair_sum; - // reduce x and x^2 - hipcub::KeyValuePair thread_data(U(0.f), U(0.f)); - - using VecT = aligned_vector; - using VecV = aligned_vector; - if (threadIdx.x * ILP < ld) { - for (int i = threadIdx.x * ILP; i < ld; i += TPB * ILP) { - int idx = offset + i; - - const VecT input_v = *reinterpret_cast(input + idx); - const VecT skip_v = *reinterpret_cast(skip + idx); - const VecT bias_v = hasBias ? 
*reinterpret_cast(bias + i) : VecT(); - VecT skip_input_bias_add_output_v, output_v; - -#pragma unroll - for (int k = 0; k < ILP; k++) { - const U val = hasBias ? static_cast(input_v.val[k]) + static_cast(skip_v.val[k]) + static_cast(bias_v.val[k]) : static_cast(input_v.val[k]) + static_cast(skip_v.val[k]); - const U rldval = reverse_ld * val; - - if (hasSkipInputBiasAdditionOutput) { - skip_input_bias_add_output_v.val[k] = static_cast(val); - } - thread_data = pair_sum(thread_data, hipcub::KeyValuePair(rldval, rldval * val)); - output_v.val[k] = static_cast(val); - } - - if (hasSkipInputBiasAdditionOutput) { - *(reinterpret_cast(skip_input_bias_add_output + idx)) = skip_input_bias_add_output_v; - } - - *(reinterpret_cast(output + idx)) = output_v; - } - } - - if constexpr (Simplified) { - SimplifiedLayerNormVec(thread_data.value, ld, offset, gamma, epsilon, output); - return; - } - - LayerNormVec(thread_data, ld, offset, beta, gamma, epsilon, output); -} - -// Vectorized kernel -template -__global__ void SkipLayerNormKernelSmall( - const int ld, const T* input, const T* skip, const V* beta, const V* gamma, - const T* bias, const U epsilon, V* output, T* skip_input_bias_add_output, - bool hasBias, bool hasSkipInputBiasAdditionOutput) { - const U rld = U(1.f / ld); - const int idx = blockIdx.x * ld + threadIdx.x * ILP; // grid_size = n / ld - - using VecT = aligned_vector; - hipcub::KeyValuePair thread_data(U(0.f), U(0.f)); - - VecT input_v; - if (ILP * threadIdx.x < ld) { - input_v = *reinterpret_cast(input + idx); - const VecT skip_v = *reinterpret_cast(skip + idx); - const VecT bias_v = hasBias ? *reinterpret_cast(bias + threadIdx.x * ILP) : VecT(); - VecT skip_input_bias_add_output_v; - - U rldval_sum = U(0.f); - U rldvalsq_sum = U(0.f); -#pragma unroll - for (int i = 0; i < ILP; i++) { - const U val = hasBias ? static_cast(input_v.val[i]) + static_cast(skip_v.val[i]) + static_cast(bias_v.val[i]) : static_cast(input_v.val[i]) + static_cast(skip_v.val[i]); - - if (hasSkipInputBiasAdditionOutput) { - skip_input_bias_add_output_v.val[i] = static_cast(val); - } - - const U rldval = rld * val; - rldval_sum += rldval; - rldvalsq_sum += rldval * val; - input_v.val[i] = static_cast(val); - } - - if (hasSkipInputBiasAdditionOutput) { - *(reinterpret_cast(skip_input_bias_add_output + idx)) = skip_input_bias_add_output_v; - } - - thread_data = hipcub::KeyValuePair(rldval_sum, rldvalsq_sum); - } - - if constexpr (Simplified) { - SimplifiedLayerNormSmall(input_v.val, thread_data.value, ld, idx, gamma, epsilon, output); - return; - } - - LayerNormSmall(input_v.val, thread_data, ld, idx, beta, gamma, epsilon, output); -} - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_tunable_op.h b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_tunable_op.h deleted file mode 100644 index 0391704ce1c56..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_tunable_op.h +++ /dev/null @@ -1,161 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
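Both deleted kernels above accumulate the two reduction moments in a single pass: each thread folds val/ld and val*val/ld into a key-value pair that is then block-reduced, so mean = E[x] and variance = E[x^2] - E[x]^2 fall out directly. The scalar identity:

    // One-pass mean/variance, as in the pair reduction above:
    // accumulate (x/ld, x*x/ld); then var = m2 - mean*mean.
    struct MeanVar { float mean; float var; };

    MeanVar OnePassMeanVar(const float* x, int ld) {
      float m1 = 0.0f, m2 = 0.0f;
      const float rld = 1.0f / ld;
      for (int i = 0; i < ld; ++i) {
        m1 += rld * x[i];         // running E[x]
        m2 += rld * x[i] * x[i];  // running E[x^2]
      }
      return {m1, m2 - m1 * m1};  // Var[x] = E[x^2] - E[x]^2
    }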
- -#pragma once - -#include -#include -#include -#include - -#include "contrib_ops/rocm/bert/skip_layer_norm_impl_kernel.h" -#include "core/providers/rocm/cu_inc/common.cuh" -#include "core/providers/rocm/tunable/rocm_tunable.h" - -using onnxruntime::rocm::CeilDiv; - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -template -struct SkipLayerNormParams : OpParams { - SkipLayerNormParams(RocmTuningContext* tuning_ctx, onnxruntime::Stream* stream, V* output, T* skip_input_bias_add_output, const T* input, - const T* skip, const V* gamma, const V* beta, - const T* bias, float epsilon, int ld, int element_count) - : OpParams(tuning_ctx, stream), output(output), skip_input_bias_add_output(skip_input_bias_add_output), input(input), skip(skip), gamma(gamma), beta(beta), bias(bias), epsilon(epsilon), ld(ld), element_count(element_count) {} - - std::string Signature() const override { - std::string sig = std::to_string(ld) + "_" + std::to_string(element_count); - return sig; - } - - V* output; - T* skip_input_bias_add_output; - const T* input; - const T* skip; - const V* gamma; - const V* beta; - const T* bias; - float epsilon; - int ld; - int element_count; -}; - -template -Status SkipLayerNormSmallOp(const SkipLayerNormParams* params) { - // Loosen the hard constraint for ld (hidden_size) to include more possible *Small kernels, - // which could offer better performance in some combinations of ThreadsPerBlock and VecSize. - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - !((params->ld <= 8192 && params->ld % VecSize == 0 && - params->ld <= ThreadsPerBlock * VecSize && params->ld > (ThreadsPerBlock - GPU_WARP_SIZE) * VecSize))); - SkipLayerNormKernelSmall<<element_count, params->ld)), - dim3(ThreadsPerBlock), - 0, params->StreamHandle()>>>( - params->ld, params->input, params->skip, - params->beta, params->gamma, params->bias, static_cast(params->epsilon), params->output, params->skip_input_bias_add_output, - (params->bias == nullptr) ? false : true, (params->skip_input_bias_add_output == nullptr) ? false : true); - return HIP_CALL(hipGetLastError()); -} - -template -Status SkipLayerNormRegularOp(const SkipLayerNormParams* params) { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - !((params->ld > 0 && params->ld % VecSize == 0 && - (params->ld >= ThreadsPerBlock * VecSize || - (params->ld < GPU_WARP_SIZE && params->ld > (ThreadsPerBlock - GPU_WARP_SIZE) * VecSize))))); - SkipLayerNormKernelVec<<element_count, params->ld)), - dim3(ThreadsPerBlock), - 0, params->StreamHandle()>>>( - params->ld, params->input, params->skip, - params->beta, params->gamma, params->bias, static_cast(params->epsilon), params->output, params->skip_input_bias_add_output, - (params->bias == nullptr) ? false : true, (params->skip_input_bias_add_output == nullptr) ? false : true); - return HIP_CALL(hipGetLastError()); -} - -template -Status SkipLayerNormStaticSelection(const SkipLayerNormParams* params) { - bool hasBias = (params->bias == nullptr) ? false : true; - bool hasSkipInputBiasAdditionOutput = (params->skip_input_bias_add_output == nullptr) ? 
false : true; - const int grid_size = params->element_count / params->ld; - const int block_size = 256; - -#define LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(ELEMENTS, TPB, ILP) \ - if (params->ld <= ELEMENTS) { \ - SkipLayerNormKernelSmall<<StreamHandle()>>>( \ - params->ld, params->input, params->skip, params->beta, params->gamma, params->bias, \ - static_cast(params->epsilon), params->output, params->skip_input_bias_add_output, \ - hasBias, hasSkipInputBiasAdditionOutput); \ - break; \ - } - if (0 == (params->ld % 4)) { - do { - LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(32, 32, 1) - LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(64, 32, 2) - LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(128, 32, 4) - LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(384, 96, 4) - LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(768, 192, 4) - LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(1024, 256, 4) - - SkipLayerNormKernel<<StreamHandle()>>>( - params->ld, params->input, params->skip, params->beta, params->gamma, params->bias, - static_cast(params->epsilon), params->output, params->skip_input_bias_add_output); - } while (0); - } else { - do { - LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(32, 32, 1) - LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(64, 64, 1) - LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(128, 128, 1) - LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(384, 384, 1) - - SkipLayerNormKernel<<StreamHandle()>>>( - params->ld, params->input, params->skip, params->beta, params->gamma, params->bias, - static_cast(params->epsilon), params->output, params->skip_input_bias_add_output); - } while (0); - } - return HIP_CALL(hipPeekAtLastError()); -} // namespace rocm - -#define ADD_OP_FOR_ALL_VEC_SIZE(name, threads_per_block) \ - this->RegisterOp(name); \ - this->RegisterOp(name); \ - this->RegisterOp(name); \ - this->RegisterOp(name); \ - this->RegisterOp(name); - -#define ADD_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(name) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 64) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 128) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 192) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 256) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 320) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 384) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 448) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 512) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 576) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 640) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 704) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 768) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 832) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 896) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 1024) - -template -class SkipLayerNormTunableOp : public TunableOp> { - public: - SkipLayerNormTunableOp() { - this->RegisterOp(SkipLayerNormStaticSelection); - ADD_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(SkipLayerNormSmallOp) - ADD_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(SkipLayerNormRegularOp) - - // NOTE: the 1st kernel is SkipLayerNorm Original implementation. - this->SetDefaultId(0); - } -}; - -#undef ADD_OP_FOR_ALL_VEC_SIZE -#undef ADD_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/transformer_common.cc b/onnxruntime/contrib_ops/rocm/bert/transformer_common.cc deleted file mode 100644 index 6ae8d1202d462..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/transformer_common.cc +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
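The registration macros in the deleted tunable op enumerate a ThreadsPerBlock x VecSize grid of kernel specializations, with the static-selection heuristic registered first so default id 0 stays the safe choice. A macro-free sketch of the same enumeration (the vector widths are illustrative, since the template arguments were elided in this diff):

    #include <string>
    #include <vector>

    // Hypothetical stand-in for TunableOp registration: one candidate per
    // (ThreadsPerBlock, VecSize) pair, plus the heuristic at index 0.
    std::vector<std::string> EnumerateCandidates() {
      std::vector<std::string> ops = {"StaticSelection"};  // default id 0
      for (int tpb : {64, 128, 192, 256, 320, 384, 448, 512,
                      576, 640, 704, 768, 832, 896, 1024})
        for (int vec : {1, 2, 4, 8, 16})  // illustrative vector widths
          for (const char* kind : {"SmallOp", "RegularOp"})
            ops.push_back(std::string(kind) + "_tpb" + std::to_string(tpb) +
                          "_v" + std::to_string(vec));
      return ops;
    }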
-#include -#include "core/providers/shared_library/provider_api.h" // Include this otherwise Windows build complains Env::Default() missing -#include "core/platform/env_var_utils.h" -#include "contrib_ops/rocm/bert/transformer_common.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -// The environment variable is for testing purpose only, and it might be removed in the future. -// If you need some option in production, please file a feature request. -constexpr const char* kTransformerOptions = "ORT_TRANSFORMER_OPTIONS"; - -// Initialize the singleton instance -TransformerOptions TransformerOptions::instance; - -const TransformerOptions* TransformerOptions::GetInstance() { - if (!instance.initialized_) { - // We do not use critical section here since it is fine to initialize multiple times by different threads. - int value = ParseEnvironmentVariableWithDefault(kTransformerOptions, 0); - instance.Initialize(value); - - if (value > 0) - std::cout << "ORT_TRANSFORMER_OPTIONS: IsPrecisionMode=" << instance.IsPrecisionMode() - << ",DisablePersistentSoftmax=" << instance.DisablePersistentSoftmax() - << ",DisableHalf2=" << instance.DisableHalf2() - << std::endl; - } - - return &instance; -} - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/transformer_common.h b/onnxruntime/contrib_ops/rocm/bert/transformer_common.h deleted file mode 100644 index 6816b5b9d07ec..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/transformer_common.h +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "core/providers/rocm/rocm_common.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -class TransformerOptions { - public: - static const TransformerOptions* GetInstance(); - - bool IsPrecisionMode() const { return is_precision_mode_; } - - bool DisablePersistentSoftmax() const { return disable_persistent_softmax_; } - - bool DisableHalf2() const { return disable_half2_; } - - void Initialize(int value) { - is_precision_mode_ = (value & 0x01) > 0; - disable_persistent_softmax_ = (value & 0x02) > 0; - disable_half2_ = (value & 0x04) > 0; - initialized_ = true; - } - - private: - // Default is false. If the mode is on, prefer precision than speed. - bool is_precision_mode_{false}; - - // Disable persistent softmax. - bool disable_persistent_softmax_{false}; - - // Disable half2 kernel. - bool disable_half2_{false}; - - bool initialized_{false}; - - static TransformerOptions instance; -}; - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh deleted file mode 100644 index d0a0d09fcbae3..0000000000000 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh +++ /dev/null @@ -1,105 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
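The deleted TransformerOptions decodes ORT_TRANSFORMER_OPTIONS as a bit mask, per its Initialize(): bit 0 turns on precision mode, bit 1 disables persistent softmax, bit 2 disables the half2 kernels. A standalone sketch of the same decoding (the real code parses the variable via ORT's env-var utilities):

    #include <cstdlib>

    struct TransformerFlags {
      bool precision_mode;              // value & 0x01
      bool disable_persistent_softmax;  // value & 0x02
      bool disable_half2;               // value & 0x04
    };

    // Decode the testing-only ORT_TRANSFORMER_OPTIONS bit mask.
    TransformerFlags DecodeTransformerOptions() {
      const char* env = std::getenv("ORT_TRANSFORMER_OPTIONS");
      const int value = env ? std::atoi(env) : 0;
      return {(value & 0x01) != 0, (value & 0x02) != 0, (value & 0x04) != 0};
    }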
- -#pragma once - -#include -#include -#include - -#ifdef USE_COMPOSABLE_KERNEL -#include "core/providers/rocm/composable_kernel_common.h" - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh" -#endif // USE_COMPOSABLE_KERNEL - -#include "contrib_ops/rocm/diffusion/group_norm_common.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -#ifdef USE_COMPOSABLE_KERNEL - -using onnxruntime::rocm::CKDataTypeAdaptor; - -// The SiLU function is a special case of Swish function, -// The Swish function is parametrized by b, which is set to 1.0 for SiLU. They are defined as: -// SiLU(x) = x * sigmoid(x) -// Swish(x) = x * sigmoid(bx) -// The default value of b is 1.0 in ck::tensor_operation::element_wise::Swish function. We treat them as the same function here. -using Silu = ck::tensor_operation::element_wise::Swish; -using Pass = ck::tensor_operation::element_wise::PassThrough; - -constexpr int Rank = 5; -constexpr int NumReduceDim = 3; - -template -auto GetCKGroupNormNHWCTypeStringAndOps() { - using XDataType = typename CKDataTypeAdaptor::type; - using YDataType = typename CKDataTypeAdaptor::type; - using SaveMeanInvStdDataType = typename CKDataTypeAdaptor::type; - using GammaDataType = float; - using BetaDataType = float; - - using Activation = std::conditional_t; - - std::vector>>> ret; - for (auto&& impl : internal::GetDeviceGroupNormInstances()) { - std::string silu_suffix = WithSilu ? "_Silu" : "_Pass"; - auto type_string = onnxruntime::MakeString(impl->GetTypeString()) + silu_suffix; - auto invoker = impl->MakeInvokerPointer(); - - auto ck_group_norm_op = [impl = std::move(impl), invoker = std::move(invoker)]( - const GroupNormNHWCTunableParams* params) -> Status { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF((params->skip != nullptr || params->bias != nullptr), - "Input skip or bias is not supported by composable kernel."); - if constexpr (WithSilu) { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - !params->use_silu, "Silu version only support groupnorm with silu"); - } else { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->use_silu, "Pass version only support groupnorm without silu"); - } - std::vector in_lengths{params->n, params->h, params->w, params->groups, params->channels_per_group}; - std::vector in_out_strides{params->h * params->w * params->c, params->w * params->c, - params->c, params->channels_per_group, 1}; - std::vector gamma_beta_strides{0, 0, 0, params->channels_per_group, 1}; - std::vector reduce_dims{1, 2, 4}; - - auto activation = Activation{}; - - auto arg = impl->MakeArgumentPointer(in_lengths, // lengths - in_out_strides, // xStrides - gamma_beta_strides, // gammaStrides - gamma_beta_strides, // betaStrides - in_out_strides, // yStrides - {0, 0}, // saveMeanStrides - {0, 0}, // saveInvStdStrides - reduce_dims, // reduceDims - params->epsilon, - params->src, - params->gamma, - params->beta, - params->dst, - nullptr, - nullptr, - activation); - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), - impl->GetTypeString(), " does not support the params"); - invoker->Run(arg.get(), StreamConfig{params->StreamHandle()}); - return Status::OK(); - }; - ret.emplace_back(std::make_pair(std::move(type_string), std::move(ck_group_norm_op))); - } - return ret; -} -#endif // USE_COMPOSABLE_KERNEL - -} // namespace rocm -} // namespace contrib -} // namespace 
onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh deleted file mode 100644 index 68f7d47282845..0000000000000 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh +++ /dev/null @@ -1,130 +0,0 @@ -// SPDX-License-Identifier: MIT -// Modifications Copyright (c) Microsoft Corporation. All rights reserved. -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#ifdef USE_COMPOSABLE_KERNEL -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_normalization_fwd.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_normalization_fwd_impl.hpp" -#include "ck/utility/data_type.hpp" - -namespace onnxruntime { -namespace contrib { -namespace rocm { -namespace internal { - -using F16 = ck::half_t; -using F32 = float; - -using Silu = ck::tensor_operation::element_wise::Swish; -using Pass = ck::tensor_operation::element_wise::PassThrough; - -using ck::tensor_operation::device::DeviceNormalizationFwd; // the interface -using ck::tensor_operation::device::DeviceNormalizationFwdImpl; // the implementation - -// See https://github.com/ROCmSoftwarePlatform/composable_kernel/blob/1fefd82ed8/library/src/tensor_operation_instance/gpu/normalization_fwd/normalization_fwd_instance_common.hpp - -template -using device_normalization_f32_instances = std::tuple< - // clang-format off - // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize, SaveMeanInvStdScalarPerVector> - DeviceNormalizationFwdImpl, // irregular size - DeviceNormalizationFwdImpl, // irregular size - DeviceNormalizationFwdImpl, // irregular size - DeviceNormalizationFwdImpl, // irregular size - DeviceNormalizationFwdImpl, // irregular size - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl - // clang-format on - >; - -template -using device_normalization_f16_instances = - // clang-format off - std::tuple < - // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize, SaveMeanInvStdScalarPerVector> - DeviceNormalizationFwdImpl, // irregular size - DeviceNormalizationFwdImpl, // irregular size - DeviceNormalizationFwdImpl, // irregular size - DeviceNormalizationFwdImpl, // irregular size - DeviceNormalizationFwdImpl, // irregular size - DeviceNormalizationFwdImpl, // irregular size - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, 
- DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl - // clang-format on - >; - -// Use this function to get implementation -template -std::vector>> -GetDeviceGroupNormInstances() { - return {}; -} - -template <> -std::vector>> -GetDeviceGroupNormInstances< - F16, F32, F32, F16, F32, Silu, 5, 3>(); - -template <> -std::vector>> -GetDeviceGroupNormInstances< - F16, F32, F32, F16, F32, Pass, 5, 3>(); - -template <> -std::vector>> -GetDeviceGroupNormInstances< - F32, F32, F32, F32, F32, Silu, 5, 3>(); - -template <> -std::vector>> -GetDeviceGroupNormInstances< - F32, F32, F32, F32, F32, Pass, 5, 3>(); - -} // namespace internal -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime -#endif // USE_COMPOSABLE_KERNEL diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp16.cu b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp16.cu deleted file mode 100644 index ad191314e5e4c..0000000000000 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp16.cu +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#ifdef USE_COMPOSABLE_KERNEL -#include "contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace onnxruntime { -namespace contrib { -namespace rocm { -namespace internal { - -template <> -std::vector>> -GetDeviceGroupNormInstances() { - std::vector>> instances; - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, - device_normalization_f16_instances{}); - - return instances; -} - -template <> -std::vector>> -GetDeviceGroupNormInstances() { - std::vector>> instances; - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, - device_normalization_f16_instances{}); - - return instances; -} - -} // namespace internal -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime -#endif // USE_COMPOSABLE_KERNEL diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp32.cu b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp32.cu deleted file mode 100644 index ceb53ed442abc..0000000000000 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp32.cu +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
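The deleted impl.cuh uses an instance-factory pattern to keep the heavy composable-kernel template instantiations out of headers: the primary GetDeviceGroupNormInstances template returns an empty list, and explicit specializations (one per dtype/activation combination) live in the impl_fp16.cu / impl_fp32.cu translation units. A skeletal illustration with stand-in types:

    #include <memory>
    #include <vector>

    // `DeviceOp` stands in for the CK device-op interface; the point is only
    // the pattern: unsupported combinations yield no candidates, and each .cu
    // file specializes the factory for the combinations it compiles.
    struct DeviceOp { virtual ~DeviceOp() = default; };

    template <typename X, typename Y>
    std::vector<std::unique_ptr<DeviceOp>> GetInstances() { return {}; }

    // e.g. impl_fp32.cu would then provide:
    // template <> std::vector<std::unique_ptr<DeviceOp>> GetInstances<float, float>() { ... }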
- -#ifdef USE_COMPOSABLE_KERNEL -#include "contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace onnxruntime { -namespace contrib { -namespace rocm { -namespace internal { - -template <> -std::vector>> -GetDeviceGroupNormInstances() { - std::vector>> instances; - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, - device_normalization_f32_instances{}); - - return instances; -} - -template <> -std::vector>> -GetDeviceGroupNormInstances() { - std::vector>> instances; - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, - device_normalization_f32_instances{}); - - return instances; -} - -} // namespace internal -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime -#endif // USE_COMPOSABLE_KERNEL diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_common.h b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_common.h deleted file mode 100644 index 7cff640db2f34..0000000000000 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_common.h +++ /dev/null @@ -1,56 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include - -#include "core/providers/rocm/cu_inc/common.cuh" -#include "core/providers/rocm/rocm_common.h" -#include "core/providers/rocm/tunable/rocm_tunable.h" -#include "contrib_ops/rocm/diffusion/group_norm_common_base.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -template -struct GroupNormNHWCTunableParams : OpParams, GroupNormNHWCParams { - GroupNormNHWCTunableParams(RocmTuningContext* tuning_ctx, - onnxruntime::Stream* ort_stream, - T* output, - T* add_out, - const T* input, - const T* skip, - const T* bias, - const float* gamma, - const float* beta, - float* workspace, - float epsilon, - int batch_size, - int num_channels, - int height, - int width, - int num_groups, - bool use_silu, - bool broadcast_skip, - int channels_per_block) - : OpParams(tuning_ctx, ort_stream), - GroupNormNHWCParams(output, add_out, input, skip, bias, gamma, beta, workspace, epsilon, batch_size, - num_channels, height, width, num_groups, use_silu, broadcast_skip, channels_per_block) {} - - std::string Signature() const override { - std::string silu_suffix = this->use_silu ? "_silu" : "_pass"; - std::string skip_suffix = this->skip != nullptr ? "_skip" : "_noskip"; - std::string broadcast_suffix = this->broadcast_skip ? "_broadcast" : "_nobroadcast"; - std::string bias_suffix = this->bias != nullptr ? "_bias" : "_nobias"; - std::string sig = std::to_string(this->n) + "_" + std::to_string(this->h * this->w) + "_" + - std::to_string(this->c) + "_" + std::to_string(this->groups) + silu_suffix + - skip_suffix + broadcast_suffix + bias_suffix; - return sig; - } -}; - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.cu b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.cu deleted file mode 100644 index 142aaf14e8d2d..0000000000000 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.cu +++ /dev/null @@ -1,76 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -// The ROCM kernel is hipified from CUDA kernel. 
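Signature() in the deleted GroupNormNHWCTunableParams serializes every field that can change which kernel wins (shape, group count, and the silu/skip/broadcast/bias modes) into the tuning-cache key. A sketch of such a key builder in the same style:

    #include <string>

    // Tuning-cache key in the style of the deleted Signature(): any field that
    // affects kernel choice must be part of the key.
    std::string GroupNormSignature(int n, int hw, int c, int groups, bool use_silu,
                                   bool has_skip, bool broadcast_skip, bool has_bias) {
      return std::to_string(n) + "_" + std::to_string(hw) + "_" + std::to_string(c) +
             "_" + std::to_string(groups) + (use_silu ? "_silu" : "_pass") +
             (has_skip ? "_skip" : "_noskip") +
             (broadcast_skip ? "_broadcast" : "_nobroadcast") +
             (has_bias ? "_bias" : "_nobias");
    }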
-#include "contrib_ops/rocm/diffusion/group_norm_impl.h" - -#include -#include "contrib_ops/rocm/diffusion/group_norm_common.h" -#include "contrib_ops/rocm/diffusion/group_norm_tunable_op.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -template -Status LaunchGroupNormKernel( - RocmTuningContext* tuning_ctx, - Stream* ort_stream, - T* output, - T* add_out, - const T* input, - const T* skip, - const T* bias, - const float* gamma, - const float* beta, - void* workspace, - float epsilon, - int batch_size, - int num_channels, - int height, - int width, - int num_groups, - bool use_silu, - bool broadcast_skip, - int channels_per_block) { - GroupNormNHWCTunableParams params(tuning_ctx, ort_stream, output, add_out, input, skip, bias, gamma, beta, - reinterpret_cast(workspace), epsilon, batch_size, num_channels, - height, width, num_groups, use_silu, broadcast_skip, channels_per_block); - - if (params.channels_per_block % params.channels_per_group != 0 || - params.channels_per_block > kMaxSize || - (params.channels_per_group % CHANNELS_PER_THREAD != 0)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, - "GroupNorm in ROCM does not support the input: n=", batch_size, - " h=", height, - " w=", width, - " c=", num_channels, - " groups=", num_groups); - } - - HIP_RETURN_IF_ERROR(hipMemsetAsync( - params.group_sum_buffer, 0, GetGroupNormWorkspaceSizeInBytes(batch_size, num_groups), params.StreamHandle())); - - if (tuning_ctx->IsTunableOpEnabled()) { - static GroupNormNHWCTunableOp op; - return op(¶ms); - } - - return GroupNormNHWCStaticSelection(¶ms); -} - -template Status LaunchGroupNormKernel(RocmTuningContext* tuning_ctx, Stream* stream, half* output, - half* add_out, const half* input, const half* skip, const half* bias, - const float* gamma, const float* beta, void* workspace, float epsilon, - int batch_size, int num_channels, int height, int width, int num_groups, - bool use_silu, bool broadcast_skip, int channels_per_block); - -template Status LaunchGroupNormKernel(RocmTuningContext* tuning_ctx, Stream* stream, float* output, - float* add_out, const float* input, const float* skip, const float* bias, - const float* gamma, const float* beta, void* workspace, float epsilon, - int batch_size, int num_channels, int height, int width, int num_groups, - bool use_silu, bool broadcast_skip, int channels_per_block); - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh deleted file mode 100644 index c6ca16bfdfc80..0000000000000 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh +++ /dev/null @@ -1,105 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include -#include - -#include "contrib_ops/rocm/diffusion/group_norm_common.h" -#include "core/providers/rocm/triton_kernel.h" - -using namespace onnxruntime::rocm; - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -#ifdef USE_TRITON_KERNEL - -namespace { - -template -std::string GetGroupNormTritonGroupName() { - std::string ret = "GroupNormTriton_"; - std::string silu_suffix = WithSilu ? 
"Silu_" : "Pass_"; - ret += silu_suffix; - ret += GetDataTypeName(); - return ret; -} - -} // namespace - -template -auto GetTritonGroupNormNHWCTypeStringAndOps() { - std::vector>>> ret; - auto group_name = GetGroupNormTritonGroupName(); - auto* kernel_list = GetOrtTritonKernelByGroup(group_name); - if (kernel_list == nullptr) { - return ret; - } - - for (auto i : *kernel_list) { - // Check params match - auto* metadata = GetOrtTritonKernelMetadata(i); - auto block_size = metadata->constants.at("BLOCK_SIZE"); - auto hw_size = metadata->constants.at("HW_SIZE"); - auto impl = [i, block_size, hw_size](const GroupNormNHWCTunableParams* params) -> Status { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->channels_per_group > block_size || params->channels_per_group * 2 <= block_size, - "Arg block_size (", block_size, ") is not the next power of 2 of channels_per_group (", - params->channels_per_group, ")."); - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->hw % hw_size != 0, "Arg hw_size (", hw_size, ") is not a divisor of hw (", params->hw, ")."); - if constexpr (WithSilu) { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!params->use_silu, "Silu version does not support GN w/o silu."); - } else { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(params->use_silu, "Pass version does not support GN w/ silu."); - } - // Construct args for launch kernel - struct { - const void* src; - const void* skip; - const void* bias; - void* out; - void* add_out; - const void* gamma; - const void* beta; - int hw; - int c; - int c_per_group; - float eps; - bool has_skip; - bool has_bias; - bool broadcast_skip; - } args = { - (const void*)params->src, - (const void*)params->skip, - (const void*)params->bias, - (void*)params->dst, - (void*)params->skip_workspace, - (const void*)params->gamma, - (const void*)params->beta, - params->hw, - params->c, - params->channels_per_group, - params->epsilon, - params->skip != nullptr, - params->bias != nullptr, - params->broadcast_skip, - }; - - // Grid dim is (batch_count, groups, 1) - return LaunchTritonKernel(params->StreamHandle(), i, params->n, params->groups, 1, &args, sizeof(args)); - }; - ret.emplace_back(std::make_pair(metadata->name, std::move(impl))); - } - return ret; -} - -#endif // USE_TRITON_KERNEL - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py deleted file mode 100644 index 5ba96ebc117f0..0000000000000 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py +++ /dev/null @@ -1,135 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. 
-# -------------------------------------------------------------------------- - -from itertools import product - -import triton -import triton.language as tl - - -@triton.jit -def group_norm_kernel( - input_ptr, - skip_ptr, - bias_ptr, - output_ptr, - add_out_ptr, - gamma_ptr, - beta_ptr, - img_size, - c, - c_per_group, - eps, - has_skip, - has_bias, - broadcast_skip, - BLOCK_SIZE: tl.constexpr, - HW_SIZE: tl.constexpr, - ACTIVATION_SILU: tl.constexpr, -): - row_x = tl.program_id(0) - row_y = tl.program_id(1) - stride = img_size * c - input_ptr += row_x * stride + row_y * c_per_group - output_ptr += row_x * stride + row_y * c_per_group - gamma_ptr += row_y * c_per_group - beta_ptr += row_y * c_per_group - - cols = tl.arange(0, BLOCK_SIZE) - hw = tl.arange(0, HW_SIZE) - offsets = hw[:, None] * c + cols[None, :] - mask = (cols < c_per_group)[None, :] - - bias = tl.zeros([BLOCK_SIZE], dtype=tl.float32) - if has_skip: - add_out_ptr += row_x * stride + row_y * c_per_group - if broadcast_skip: - broadcast_skip_ptr = skip_ptr + row_x * c + row_y * c_per_group - bias += tl.load(broadcast_skip_ptr + cols, mask=cols < c_per_group, other=0.0).to(tl.float32) - else: - skip_ptr += row_x * stride + row_y * c_per_group - if has_bias: - bias_ptr += row_y * c_per_group - bias += tl.load(bias_ptr + cols, mask=cols < c_per_group, other=0.0).to(tl.float32) - - # Calculate mean and variance - _sum = tl.zeros([HW_SIZE, BLOCK_SIZE], dtype=tl.float32) - _square_sum = tl.zeros([HW_SIZE, BLOCK_SIZE], dtype=tl.float32) - for i in range(tl.cdiv(img_size, HW_SIZE)): - x_ptr = input_ptr + i * HW_SIZE * c - a = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32) - if has_skip and not broadcast_skip: - s_ptr = skip_ptr + i * HW_SIZE * c - s = tl.load(s_ptr + offsets, mask=mask, other=0.0).to(tl.float32) - a += s - if has_bias or broadcast_skip: - a += bias - _sum += a - _square_sum += a * a - if has_skip: - add_y_ptr = add_out_ptr + i * HW_SIZE * c - tl.store(add_y_ptr + offsets, a, mask=mask) - - # Set axis=None (or leave it unspecified) to reduce all axes. - # TODO: In older Triton we have to reduce an axis at a time, but in our case - # for some configs it may have some issue when reducing sequentially along the axes. - group_mean = tl.sum(_sum, axis=None) / (img_size * c_per_group) - group_var = tl.sum(_square_sum, axis=None) / (img_size * c_per_group) - group_mean * group_mean - - rstd = 1 / tl.sqrt(group_var + eps) - - # Normalize and apply linear transformation - gamma = tl.load(gamma_ptr + cols, mask=cols < c_per_group).to(tl.float32) - beta = tl.load(beta_ptr + cols, mask=cols < c_per_group).to(tl.float32) - for i in range(tl.cdiv(img_size, HW_SIZE)): - y_ptr = output_ptr + i * HW_SIZE * c - if has_skip: - add_y_ptr = add_out_ptr + i * HW_SIZE * c - x = tl.load(add_y_ptr + offsets, mask=mask, other=0.0).to(tl.float32) - else: - x_ptr = input_ptr + i * HW_SIZE * c - x = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32) - x_hat = (x - group_mean) * rstd - y = x_hat * gamma + beta - if ACTIVATION_SILU: - y *= tl.sigmoid(y) - tl.store(y_ptr + offsets, y, mask=mask) - - -# We can have more combinations of blocks and hw_sizes, e.g., -# blocks = [16, 32, 64, 128, 256, 512] -# hw_sizes = [8, 16, 32, 64, 128, 256, 512] -# but this will result in too many functions and slow down the compilation. 
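Each program instance of the deleted Triton kernel handles one (batch, group) pair: the first pass accumulates the sum and squared sum over img_size x channels_per_group elements in fp32, and the second pass normalizes, applies gamma/beta, and optionally SiLU. A scalar reference of the no-skip path:

    #include <cmath>

    // Scalar reference for one (batch, group) slice of NHWC GroupNorm, matching
    // the two-pass structure above. `x` points at the group's first channel;
    // `stride` is the full channel count c separating spatial positions.
    void GroupNormSlice(const float* x, const float* gamma, const float* beta,
                        float* y, int img_size, int c_per_group, int stride,
                        float eps, bool silu) {
      double sum = 0.0, sq = 0.0;
      for (int p = 0; p < img_size; ++p)
        for (int k = 0; k < c_per_group; ++k) {
          const double v = x[p * stride + k];
          sum += v;
          sq += v * v;
        }
      const double n = static_cast<double>(img_size) * c_per_group;
      const float mean = static_cast<float>(sum / n);
      const float rstd = 1.0f / std::sqrt(static_cast<float>(sq / n) - mean * mean + eps);
      for (int p = 0; p < img_size; ++p)
        for (int k = 0; k < c_per_group; ++k) {
          float v = (x[p * stride + k] - mean) * rstd * gamma[k] + beta[k];
          if (silu) v *= 1.0f / (1.0f + std::exp(-v));  // SiLU: v * sigmoid(v)
          y[p * stride + k] = v;
        }
    }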
-with_silu = [True, False] -dtypes = ["fp32", "fp16"] -blocks = [16, 32, 64, 128] -hw_sizes = [8, 16, 32, 64, 128, 256] -warps = [1, 2, 4, 8, 16] -name_pattern = "GroupNormTriton_{}_{}_b{}_hw{}_w{}" -sig_pattern = "*{},*{},*{},*{},*{},*fp32,*fp32,i32,i32,i32,fp32,i1,i1,i1" -group_pattern = "GroupNormTriton_{}_{}" - - -def get_function_table(): - func_table = [] - - for silu, dtype, hw_size, warp, b in product(with_silu, dtypes, hw_sizes, warps, blocks): - silu_suffix = "Silu" if silu else "Pass" - name = name_pattern.format(silu_suffix, dtype, b, hw_size, warp) - group = group_pattern.format(silu_suffix, dtype) - sig = sig_pattern.format(dtype, dtype, dtype, dtype, dtype) - kwargs = { - "num_warps": warp, - "constants": {"BLOCK_SIZE": b, "HW_SIZE": hw_size, "ACTIVATION_SILU": int(silu)}, - } - func_desc = {"name": name, "group": group, "func": group_norm_kernel, "sig": sig, "kwargs": kwargs} - func_table.append(func_desc) - return func_table - - -if __name__ == "__main__": - func_table = get_function_table() - for func_desc in func_table: - print(func_desc) diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_tunable_op.h b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_tunable_op.h deleted file mode 100644 index e6831f764b418..0000000000000 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_tunable_op.h +++ /dev/null @@ -1,220 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include -#include "core/providers/rocm/cu_inc/common.cuh" -#include "core/providers/rocm/rocm_common.h" -#include "contrib_ops/rocm/diffusion/group_norm_ck.cuh" -#include "contrib_ops/rocm/diffusion/group_norm_common.h" -#include "contrib_ops/rocm/diffusion/group_norm_impl.h" -#include "contrib_ops/rocm/diffusion/group_norm_impl_kernel.cuh" -#include "contrib_ops/rocm/diffusion/group_norm_triton.cuh" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -using onnxruntime::rocm::GPU_WARP_SIZE; - -template -void GroupNormNHWCSum(const GroupNormNHWCTunableParams* params) { - dim3 grid; - - // The number of blocks to compute all the channels. - grid.x = DivUp(params->c, params->channels_per_block); - // The number of blocks to compute all the activations in a given instance. - grid.y = DivUp(params->hw, params->hw_per_block); - // The number of instances. - grid.z = params->n; - -#define LAUNCH_GROUPNORM_SUM(ThreadsPerBlock, VecSize) \ - GroupNormNHWCSumKernel \ - <<StreamHandle()>>>( \ - params->skip_workspace, params->group_sum_buffer, params->src, params->skip, params->bias, \ - params->channels_per_block, params->hw_per_block, params->hw, params->hwc, params->c, \ - params->channels_per_group, params->groups, params->groups_per_block, params->broadcast_skip); \ - break; - - // Threads_per_block is half of values in kSizes since CHANNELS_PER_THREAD = 2. 
- switch (params->threads_per_block) { - case 256: - LAUNCH_GROUPNORM_SUM(256, CHANNELS_PER_THREAD) - case 192: - LAUNCH_GROUPNORM_SUM(192, CHANNELS_PER_THREAD) - case 160: - LAUNCH_GROUPNORM_SUM(160, CHANNELS_PER_THREAD) - case 128: - LAUNCH_GROUPNORM_SUM(128, CHANNELS_PER_THREAD) - case 64: - LAUNCH_GROUPNORM_SUM(64, CHANNELS_PER_THREAD) - default: - ORT_NOT_IMPLEMENTED("Not implemented"); - } -} - -template -Status GroupNormNHWCSumOp(const GroupNormNHWCTunableParams* params) { - dim3 grid; - grid.x = DivUp(params->c, params->channels_per_block); - grid.y = DivUp(params->hw, params->hw_per_block); - grid.z = params->n; - - GroupNormNHWCSumKernel - <<StreamHandle()>>>( - params->skip_workspace, params->group_sum_buffer, params->src, params->skip, params->bias, - params->channels_per_block, params->hw_per_block, params->hw, params->hwc, params->c, - params->channels_per_group, params->groups, params->groups_per_block, params->broadcast_skip); - return HIP_CALL(hipGetLastError()); -} - -template -void GroupNormNHWCScale(const GroupNormNHWCTunableParams* params) { - dim3 grid; - - // The number of blocks to compute all the channels. - grid.x = DivUp(params->c, params->channels_per_block); - // The number of blocks to compute all the activations in a given instance. - grid.y = DivUp(params->hw, params->hw_per_block); - // The number of instances. - grid.z = params->n; - -#define LAUNCH_GROUPNORM_SCALE(ThreadsPerBlock, VecSize) \ - GroupNormNHWCScaleKernel \ - <<StreamHandle()>>>( \ - params->dst, params->src, params->skip, params->gamma, params->beta, params->skip_workspace, \ - params->group_sum_buffer, params->epsilon, params->c, params->channels_per_block, \ - params->channels_per_group, params->groups, params->hwc, params->inv_hw_channels_per_group, \ - params->hw, params->hw_per_block, params->use_silu); \ - break; - - // Threads_per_block is half of values in kSizes since CHANNELS_PER_THREAD = 2. 
- switch (params->threads_per_block) { - case 256: - LAUNCH_GROUPNORM_SCALE(256, CHANNELS_PER_THREAD) - case 192: - LAUNCH_GROUPNORM_SCALE(192, CHANNELS_PER_THREAD) - case 160: - LAUNCH_GROUPNORM_SCALE(160, CHANNELS_PER_THREAD) - case 128: - LAUNCH_GROUPNORM_SCALE(128, CHANNELS_PER_THREAD) - case 64: - LAUNCH_GROUPNORM_SCALE(64, CHANNELS_PER_THREAD) - default: - ORT_NOT_IMPLEMENTED("Not implemented"); - } -} - -template -Status GroupNormNHWCScaleOp(const GroupNormNHWCTunableParams* params) { - dim3 grid; - grid.x = DivUp(params->c, params->channels_per_block); - grid.y = DivUp(params->hw, params->hw_per_block); - grid.z = params->n; - - GroupNormNHWCScaleKernel - <<StreamHandle()>>>( - params->dst, params->src, params->skip, params->gamma, params->beta, params->skip_workspace, - params->group_sum_buffer, params->epsilon, params->c, params->channels_per_block, params->channels_per_group, - params->groups, params->hwc, params->inv_hw_channels_per_group, params->hw, params->hw_per_block, - params->use_silu); - return HIP_CALL(hipGetLastError()); -} - -template -class GroupNormNHWCOp { - public: - Status operator()(const GroupNormNHWCTunableParams* params) { - HIP_RETURN_IF_ERROR(hipMemsetAsync(params->group_sum_buffer, - 0, - GetGroupNormWorkspaceSizeInBytes(params->n, params->groups), - params->StreamHandle())); - auto status = GroupNormNHWCSumOp(params); - ORT_RETURN_IF_ERROR(status); - HIP_RETURN_IF_ERROR(hipGetLastError()); - status = GroupNormNHWCScaleOp(params); - ORT_RETURN_IF_ERROR(status); - HIP_RETURN_IF_ERROR(hipGetLastError()); - return Status::OK(); - } - - Status IsSupported(const GroupNormNHWCTunableParams* params) { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - !(params->c % VecSize == 0 && params->channels_per_group % VecSize == 0), - "The number of channels (", params->c, ") or the number of channels per group (", params->channels_per_group, - ") isn't divisible by the number of vector size: ", VecSize); - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!(params->channels_per_block <= ThreadsPerBlock * VecSize && - params->channels_per_block > (ThreadsPerBlock - GPU_WARP_SIZE) * VecSize), - "Configuration: Threads (", ThreadsPerBlock, "), vector size (", - VecSize, ") is redundant for the number of channels per group: ", - params->channels_per_block); - - return Status::OK(); - } -}; - -template -Status GroupNormNHWCStaticSelection(const GroupNormNHWCTunableParams* params) { - HIP_RETURN_IF_ERROR(hipMemsetAsync(params->group_sum_buffer, - 0, - GetGroupNormWorkspaceSizeInBytes(params->n, params->groups), - params->StreamHandle())); - GroupNormNHWCSum(params); - HIP_RETURN_IF_ERROR(hipGetLastError()); - GroupNormNHWCScale(params); - HIP_RETURN_IF_ERROR(hipGetLastError()); - return Status::OK(); -} - -#define ADD_OP_FOR_ALL_VEC_SIZE(name, threads_per_block) \ - this->RegisterOp(name{}); \ - this->RegisterOp(name{}); \ - this->RegisterOp(name{}); - -#define ADD_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(name) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 64) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 128) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 192) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 256) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 320) - -template -class GroupNormNHWCTunableOp : public TunableOp> { - public: - GroupNormNHWCTunableOp() { - this->RegisterOp(GroupNormNHWCStaticSelection); - ADD_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(GroupNormNHWCOp) - -#ifdef USE_COMPOSABLE_KERNEL - for (auto&& [_, op] : GetCKGroupNormNHWCTypeStringAndOps()) { - ORT_UNUSED_PARAMETER(_); - this->RegisterOp(std::move(op)); - } - - for 
(auto&& [_, op] : GetCKGroupNormNHWCTypeStringAndOps()) { - ORT_UNUSED_PARAMETER(_); - this->RegisterOp(std::move(op)); - } -#endif // USE_COMPOSABLE_KERNEL - -#ifdef USE_TRITON_KERNEL - for (auto&& [_, op] : GetTritonGroupNormNHWCTypeStringAndOps()) { - ORT_UNUSED_PARAMETER(_); - this->RegisterOp(std::move(op)); - } - for (auto&& [_, op] : GetTritonGroupNormNHWCTypeStringAndOps()) { - ORT_UNUSED_PARAMETER(_); - this->RegisterOp(std::move(op)); - } -#endif - } -}; - -#undef ADD_OP_FOR_ALL_VEC_SIZE -#undef ADD_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/diffusion/nhwc_conv.cc b/onnxruntime/contrib_ops/rocm/diffusion/nhwc_conv.cc deleted file mode 100644 index 35427a02c631d..0000000000000 --- a/onnxruntime/contrib_ops/rocm/diffusion/nhwc_conv.cc +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/providers/rocm/nn/conv.h" - -using namespace onnxruntime::rocm; - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -#define REGISTER_KERNEL_TYPED(T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - NhwcConv, \ - kMSDomain, \ - 1, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - Conv); - -REGISTER_KERNEL_TYPED(float) -REGISTER_KERNEL_TYPED(MLFloat16) - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/fused_conv.cc b/onnxruntime/contrib_ops/rocm/fused_conv.cc deleted file mode 100644 index 4f3be98d97f80..0000000000000 --- a/onnxruntime/contrib_ops/rocm/fused_conv.cc +++ /dev/null @@ -1,439 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include -#include -#include "core/common/status.h" -#include "core/providers/rocm/nn/conv.h" -#include "core/providers/rocm/rocm_common.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -namespace { - -// Copied from hipDNN/library/src/hcc_detail/hipdnn_miopen.cpp -miopenStatus_t _miopenAddTensor( - miopenHandle_t handle, - const void* alpha, - const miopenTensorDescriptor_t aDesc, - const void* A, - const void* beta, - const miopenTensorDescriptor_t cDesc, - void* C, - const void* zero_scalar) { - const miopenTensorOp_t tensorOp = miopenTensorOpAdd; - // Using miopenOpTensor to implement Add operator. 
- // opnd2 = Add ( 0.0 * opnd0, alpha * opnd1 ) + beta * opnd2 - return miopenOpTensor(handle, tensorOp, - zero_scalar, cDesc, C, - alpha, aDesc, A, - beta, cDesc, C); -} - -} // namespace - -template -struct FNVHash { - uint32_t GetValue() const { return value_; } - - void Hash(const void* in_ptr, size_t nbytes) { - auto ptr = reinterpret_cast(in_ptr); - for (size_t i = 0; i < nbytes; ++i) { - value_ ^= ptr[i]; - value_ *= PRIME; - } - } - - template ::value, size_t>::type = 0> - FNVHash& operator<<(const T& pod) { - Hash(&pod, sizeof(pod)); - return *this; - } - - template - FNVHash& operator<<(const std::vector& pod_array) { - for (const auto& pod : pod_array) { - (*this) << pod; - } - return *this; - } - - void HashTensor(miopenTensorDescriptor_t tdesc) { - int size = 0; - miopenGetTensorDescriptorSize(tdesc, &size); - (*this) << size; - std::vector dims(size); - std::vector strides(size); - miopenDataType_t dtype; - miopenGetTensorDescriptor(tdesc, &dtype, dims.data(), strides.data()); - (*this) << dtype; - (*this) << dims; - (*this) << strides; - } - - void HashConvolutionDescriptor(miopenConvolutionDescriptor_t cdesc) { - int spatial_dim = 1; -#if ROCM_VERSION >= 50500 - MIOPEN_CALL(miopenGetConvolutionSpatialDim(cdesc, &spatial_dim)); - std::vector pads{spatial_dim}; - std::vector strides{spatial_dim}; - std::vector dilations{spatial_dim}; - miopenConvolutionMode_t mode; - MIOPEN_CALL(miopenGetConvolutionNdDescriptor(cdesc, spatial_dim, &spatial_dim, pads.data(), strides.data(), dilations.data(), &mode)); -#else - // Previous versions of MIOpen doesn't provide API to probe the dimension of a - // miopenConvolutionDescriptor_t, so we have to guess. - // This algorithm is based on a specific behavior of miopenGetConvolutionNdDescriptor, - // which fails when requestedSpatialDim > the convolution's spatial dimension - constexpr const int kMaxSpatialDim = 5; - std::vector pads{kMaxSpatialDim}; - std::vector strides{kMaxSpatialDim}; - std::vector dilations{kMaxSpatialDim}; - miopenConvolutionMode_t mode; - bool spatial_dim_guessed = false; - for (int i = 0; i < kMaxSpatialDim; i++) { - if (miopenStatusSuccess == miopenGetConvolutionNdDescriptor( - cdesc, i, &spatial_dim, pads.data(), strides.data(), dilations.data(), &mode)) { - spatial_dim_guessed = true; - break; - } - } - ORT_ENFORCE(spatial_dim_guessed, "Failed to guess the actual spatial dimension"); - // Remove the extra dimension - pads.resize(spatial_dim); - strides.resize(spatial_dim); - dilations.resize(spatial_dim); -#endif - (*this) << spatial_dim; - (*this) << pads; - (*this) << strides; - (*this) << dilations; - (*this) << mode; - } - - private: - uint32_t value_ = BASIS; -}; - -template -class FusedConv : public onnxruntime::rocm::Conv { - public: - using Base = onnxruntime::rocm::Conv; - FusedConv(const OpKernelInfo& info) : onnxruntime::rocm::Conv(info) { - std::string activation; - ORT_THROW_IF_ERROR(info.GetAttr("activation", &activation)); - ORT_THROW_IF_ERROR(MapMode(activation)); - MIOPEN_CALL_THROW(miopenCreateActivationDescriptor(&activation_desc_)); - MIOPEN_CALL_THROW(miopenSetActivationDescriptor(activation_desc_, activation_mode_, 0.0, 0.0, 0.0)); - MIOPEN_CALL_THROW(miopenCreateOperatorArgs(&fusion_args_)); - } - - ORT_DISALLOW_COPY_AND_ASSIGNMENT(FusedConv); - - ~FusedConv() { - if (activation_desc_) { - MIOPEN_CALL_THROW(miopenDestroyActivationDescriptor(activation_desc_)); - activation_desc_ = nullptr; - } - - if (fusion_args_) { - miopenDestroyOperatorArgs(fusion_args_); - } - } - - Status 
ComputeInternal(OpKernelContext* context) const override { - std::lock_guard lock(Base::s_.mutex); - - ORT_RETURN_IF_ERROR(Base::UpdateState(context, true)); - if (Base::s_.Y->Shape().Size() == 0) { - return Status::OK(); - } - - bool has_z = nullptr != Base::s_.z_data; - bool has_b = nullptr != Base::s_.b_data; - auto factory = [this](FusedConvFusionData& fusion) { - return this->DoCreateFusionDesc(this->Node().Name(), fusion); - }; - auto& cached_item = plan_cache_.FindOrCreateFusionPlanCache(Hash(), - factory); - bool should_try_fusion_api = cached_item.Validate(this->GetMiopenHandle(context)); - - typedef typename onnxruntime::rocm::ToHipType::MappedType HipT; - const auto alpha = onnxruntime::rocm::Consts::One; - const auto beta = onnxruntime::rocm::Consts::Zero; - IAllocatorUniquePtr workspace = Base::GetWorkSpace(context->GetComputeStream()); - miopenStatus_t fusion_status = miopenStatusNotInitialized; - - if (should_try_fusion_api) { - auto& fusion_info = *cached_item.fusion; - MIOPEN_RETURN_IF_ERROR(miopenSetOpArgsConvForward(fusion_args_, - fusion_info.conv_op, - &alpha, - &beta, - Base::s_.w_data)); - if (has_z) { - MIOPEN_RETURN_IF_ERROR(miopenSetOpArgsBiasForward(fusion_args_, - fusion_info.bias_z_op, - &alpha, - &beta, - Base::s_.z_data)); - } - if (has_b) { - MIOPEN_RETURN_IF_ERROR(miopenSetOpArgsBiasForward(fusion_args_, - fusion_info.bias_b_op, - &alpha, - &beta, - Base::s_.b_data)); - } - if (activation_desc_) { - const float relu_notused = 0.0; - MIOPEN_RETURN_IF_ERROR(miopenSetOpArgsActivForward(fusion_args_, - fusion_info.act_op, - &alpha, - &beta, - relu_notused, - relu_notused, - relu_notused)); - } - fusion_status = miopenExecuteFusionPlan(this->GetMiopenHandle(context), - fusion_info.plan, - Base::s_.x_tensor, - Base::s_.x_data, - Base::s_.y_tensor, - Base::s_.y_data, - fusion_args_); - } - if (miopenStatusSuccess != fusion_status) { - MIOPEN_RETURN_IF_ERROR(miopenConvolutionForward(this->GetMiopenHandle(context), - &alpha, - Base::s_.x_tensor, - Base::s_.x_data, - Base::s_.w_desc, - Base::s_.w_data, - Base::s_.conv_desc, - Base::s_.fwd_algo, - &beta, - Base::s_.y_tensor, - Base::s_.y_data, - workspace.get(), - Base::s_.workspace_bytes)); - if (has_b) { - MIOPEN_RETURN_IF_ERROR(_miopenAddTensor(this->GetMiopenHandle(context), - &alpha, Base::s_.b_tensor, Base::s_.b_data, - &alpha, Base::s_.y_tensor, Base::s_.y_data, - &beta)); - } - if (has_z) { - MIOPEN_RETURN_IF_ERROR(_miopenAddTensor(this->GetMiopenHandle(context), - &alpha, Base::s_.z_tensor, Base::s_.z_data, - &alpha, Base::s_.y_tensor, Base::s_.y_data, - &beta)); - } - MIOPEN_RETURN_IF_ERROR(miopenActivationForward(this->GetMiopenHandle(context), - activation_desc_, - &alpha, - Base::s_.y_tensor, - Base::s_.y_data, - &beta, - Base::s_.y_tensor, - Base::s_.y_data)); - } - if (Base::s_.post_slicing_required) { - ORT_RETURN_IF_ERROR(onnxruntime::rocm::SliceOutUnwantedOutputSection( - this->Stream(context), - Base::s_.y_data, - Base::s_.y_dims_with_adjusted_pads, - Base::s_.Y->MutableDataRaw(), - Base::s_.y_dims.GetDims(), - Base::s_.slice_starts, - Base::s_.slice_ends, - Base::s_.slice_axes, - Base::s_.element_size)); - } - return Status::OK(); - } - - private: - Status MapMode(const std::string& activaton_mode) { - if (activaton_mode == "Relu") { - activation_mode_ = miopenActivationMode_t::miopenActivationRELU; - } else { - return ORT_MAKE_STATUS( - StatusCategory::ONNXRUNTIME, StatusCode::INVALID_ARGUMENT, - "unsupported conv activation mode \"", activaton_mode, "\""); - } - return Status::OK(); - } - 
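A note on the cache key used below: the fusion-plan cache keys its entries on an FNV hash of the tensor and convolution descriptors (the FNVHash template above, whose BASIS/PRIME parameters were lost in extraction but presumably instantiate the standard constants). A minimal standalone sketch of that hashing step, with the usual 32-bit FNV-1a basis and prime:

#include <cstddef>
#include <cstdint>

// FNV-1a over raw bytes: xor one byte in, then multiply by the FNV prime.
// 2166136261 is the 32-bit offset basis, 16777619 the 32-bit prime.
inline uint32_t Fnv1a32(const void* data, std::size_t nbytes) {
  const auto* p = static_cast<const uint8_t*>(data);
  uint32_t h = 2166136261u;
  for (std::size_t i = 0; i < nbytes; ++i) {
    h ^= p[i];
    h *= 16777619u;
  }
  return h;
}

Hashing dims, strides, dtype, and convolution parameters this way lets two FusedConv instances with identical geometry share one compiled fusion plan.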
miopenActivationMode_t activation_mode_;
-  miopenActivationDescriptor_t activation_desc_ = nullptr;
-
-  miopenOperatorArgs_t fusion_args_ = nullptr;
-
-  // MIOpen Fusion API
-  // TODO: create one fusion descriptor shared by multiple FusedConv
-  // objects
-  //
-  // Considerations:
-  // How do we determine that two FusedConv objects may share the same fusion
-  // descriptor? Hashing x_tensor, conv_desc, etc.?
-  struct FusedConvFusionData {
-    miopenFusionPlanDescriptor_t plan = nullptr;
-    miopenFusionOpDescriptor_t conv_op = nullptr;
-    miopenFusionOpDescriptor_t bias_b_op = nullptr;
-    miopenFusionOpDescriptor_t bias_z_op = nullptr;
-    miopenFusionOpDescriptor_t act_op = nullptr;
-
-    // TODO: There is a potential problem. miopenHandle_t may be destroyed and
-    // re-created later, sharing the same address. Currently there is no way
-    // to detect it.
-    mutable std::unordered_set<miopenHandle_t> compiled_on;
-
-    ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(FusedConvFusionData);
-
-    FusedConvFusionData() {}
-    ~FusedConvFusionData() {
-      if (plan) {
-        miopenDestroyFusionPlan(plan);
-      }
-    }
-  };
-
-  struct FusionPlanCacheItem {
-    std::unique_ptr<FusedConvFusionData> fusion;
-    Status creation_result;
-    // TODO: Add a timestamp for eviction
-    // std::chrono::time_point last_access;
-
-    FusionPlanCacheItem() {}
-
-    miopenStatus_t CompileOnHandle(miopenHandle_t handle) const {
-      if (!fusion->plan) {
-        return miopenStatusNotInitialized;
-      }
-      auto iter = fusion->compiled_on.find(handle);
-      if (iter != fusion->compiled_on.end()) {
-        return miopenStatusSuccess;
-      }
-      auto ret = miopenCompileFusionPlan(handle, fusion->plan);
-      if (miopenStatusSuccess == ret) {
-        fusion->compiled_on.insert(handle);
-      } else {
-        return ret;
-      }
-      return miopenStatusSuccess;
-    }
-
-    bool Validate(miopenHandle_t handle) const {
-      if (Status::OK() != creation_result) {
-        return false;
-      }
-      if (!fusion || !fusion->plan) {
-        return false;
-      }
-      auto compiling_status = CompileOnHandle(handle);
-      if (miopenStatusSuccess != compiling_status) {
-        return false;
-      }
-
-      return true;
-    }
-  };
-
-  struct FusionPlanCache {
-    mutable std::mutex mutex;
-    using HashKey = uint32_t;
-    std::unordered_map<HashKey, FusionPlanCacheItem> cache_directory_;
-
-    FusionPlanCache() {
-    }
-
-    FusionPlanCacheItem& FindOrCreateFusionPlanCache(HashKey key,
-                                                     std::function<Status(FusedConvFusionData&)> factory) {
-      std::lock_guard<std::mutex> lock(mutex);
-      auto iter = cache_directory_.find(key);
-      if (iter == cache_directory_.end()) {
-        cache_directory_[key].fusion = std::make_unique<FusedConvFusionData>();
-        cache_directory_[key].creation_result = factory(*cache_directory_[key].fusion);
-        if (Status::OK() != cache_directory_[key].creation_result) {
-          cache_directory_[key].fusion.reset();
-        }
-      }
-      return cache_directory_[key];
-    }
-  };
-
-  static FusionPlanCache plan_cache_;
-
-  Status DoCreateFusionDesc(const std::string& node_name, FusedConvFusionData& fusion) const {
-    bool has_z = nullptr != Base::s_.z_data;
-    bool has_b = nullptr != Base::s_.b_data;
-    MIOPEN_RETURN_IF_ERROR(miopenCreateFusionPlan(&fusion.plan,
-                                                  miopenVerticalFusion,
-                                                  Base::s_.x_tensor));
-    auto status = miopenCreateOpConvForward(fusion.plan, &fusion.conv_op, Base::s_.conv_desc, Base::s_.w_desc);
-    if (status == miopenStatusUnsupportedOp) {
-      auto msg = MakeString("MIOpen does not support the conv fusion for node \"",
-                            node_name, "\", falling back to the unfused implementation.");
-      LOGS_DEFAULT(WARNING) << msg;
-      return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, msg);
-    }
-    MIOPEN_RETURN_IF_ERROR(status);
-
-    if (has_z) {
-      MIOPEN_RETURN_IF_ERROR(miopenCreateOpBiasForward(fusion.plan,
-                                                       &fusion.bias_z_op,
-                                                       Base::s_.z_tensor));
- } - if (has_b) { - MIOPEN_RETURN_IF_ERROR(miopenCreateOpBiasForward(fusion.plan, - &fusion.bias_b_op, - Base::s_.b_tensor)); - } - if (activation_desc_) { - MIOPEN_RETURN_IF_ERROR(miopenCreateOpActivationForward(fusion.plan, - &fusion.act_op, - activation_mode_)); - } - return Status::OK(); - } - - uint32_t Hash() const { - FNVHash hash; - bool has_z = nullptr != Base::s_.z_data; - bool has_b = nullptr != Base::s_.b_data; - hash.HashTensor(Base::s_.x_tensor); - hash.HashConvolutionDescriptor(Base::s_.conv_desc); - hash.HashTensor(Base::s_.w_desc); - if (has_z) { - hash.HashTensor(Base::s_.z_tensor); - } - if (has_b) { - hash.HashTensor(Base::s_.b_tensor); - } - if (activation_desc_) { - hash << static_cast(activation_mode_); - } - return hash.GetValue(); - } -}; - -template -typename FusedConv::FusionPlanCache FusedConv::plan_cache_; - -#define REGISTER_KERNEL_TYPED(T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - FusedConv, \ - kMSDomain, \ - 1, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - FusedConv); - -REGISTER_KERNEL_TYPED(float); -REGISTER_KERNEL_TYPED(MLFloat16); -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/math/gemm_float8.cu b/onnxruntime/contrib_ops/rocm/math/gemm_float8.cu deleted file mode 100644 index 3539f32252944..0000000000000 --- a/onnxruntime/contrib_ops/rocm/math/gemm_float8.cu +++ /dev/null @@ -1,213 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/common/common.h" -#include "core/common/float16.h" -#include "core/providers/rocm/rocm_kernel.h" -#include "contrib_ops/rocm/math/gemm_float8_ck.cuh" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -using namespace onnxruntime::rocm; -using namespace onnxruntime::rocm::tunable::blas; - -class GemmFloat8 final : public RocmKernel { - public: - GemmFloat8(const OpKernelInfo& info) : RocmKernel(info) { - transA_ = info.GetAttrOrDefault("transA", 0); - transB_ = info.GetAttrOrDefault("transB", 0); - dtype_ = info.GetAttrOrDefault("dtype", onnx::TensorProto_DataType_FLOAT16); - alpha_ = info.GetAttrOrDefault("alpha", 1); - beta_ = info.GetAttrOrDefault("beta", 0); - } - Status ComputeInternal(OpKernelContext* ctx) const override; - - private: -#if !defined(DISABLE_FLOAT8_TYPES) - template - Status ComputeFp8Fp16Fp16(OpKernelContext* ctx, int64_t m, int64_t n, int64_t k, - const Tensor* A, const Tensor* scaleA, const Tensor* B, Tensor* C) const; - template - Status ComputeFp16Fp8Fp16(OpKernelContext* ctx, int64_t m, int64_t n, int64_t k, - const Tensor* A, const Tensor* B, const Tensor* scaleB, Tensor* C) const; - - template - [[nodiscard]] inline auto* GetOp() const { - using OpT = GemmFloat8TunableOp; - if (tunable_op_) { - return static_cast(tunable_op_.get()); - } - - auto create = std::make_unique(); // avoid new - tunable_op_ = std::shared_ptr(create.release(), [](void* ptr) { - auto release = std::unique_ptr(); // avoid delete - release.reset(static_cast(ptr)); - }); - - return static_cast(tunable_op_.get()); - } -#endif - - float alpha_; - float beta_; - bool transA_; - bool transB_; - int64_t dtype_; - - // fully type erased - mutable std::shared_ptr tunable_op_; -}; - -Status GemmFloat8::ComputeInternal(OpKernelContext* ctx) const { -#if defined(DISABLE_FLOAT8_TYPES) - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "DISABLE_FLOAT8_TYPES"); -#else - const Tensor* A = ctx->Input(0); - const 
Tensor* B = ctx->Input(1); - const Tensor* C = ctx->Input(2); // bias - const Tensor* scale_a = ctx->Input(3); - const Tensor* scale_b = ctx->Input(4); - const Tensor* scale_y = ctx->Input(5); - - auto a_shape = A->Shape(); - auto b_shape = B->Shape(); - ORT_ENFORCE(a_shape.NumDimensions() == 2); - ORT_ENFORCE(b_shape.NumDimensions() == 2); - - auto m = !transA_ ? a_shape[0] : a_shape[1]; - auto k = !transA_ ? a_shape[1] : a_shape[0]; - ORT_ENFORCE(k == (!transB_ ? b_shape[0] : b_shape[1])); // k is compatible - auto n = !transB_ ? b_shape[1] : b_shape[0]; - - TensorShapeVector output_shape = {m, n}; - Tensor* Y = ctx->Output(0, output_shape); - - ORT_ENFORCE(!transA_, "ROCm GemmFloat8 does not support input A transpose"); - ORT_ENFORCE(dtype_ == onnx::TensorProto_DataType_FLOAT16, "ROCm GemmFloat8 only supports output float16"); - ORT_ENFORCE(C == nullptr, "ROCm GemmFloat8 does not support bias input"); - ORT_ENFORCE(scale_y == nullptr, "ROCm GemmFloat8 does not support output scaling"); - - if (A->IsDataType()) { - return ComputeFp8Fp16Fp16(ctx, m, n, k, A, scale_a, B, Y); - } else if (A->IsDataType()) { - return ComputeFp8Fp16Fp16(ctx, m, n, k, A, scale_a, B, Y); - } else if (B->IsDataType()) { - return ComputeFp16Fp8Fp16(ctx, m, n, k, A, B, scale_b, Y); - } else if (B->IsDataType()) { - return ComputeFp16Fp8Fp16(ctx, m, n, k, A, B, scale_b, Y); - } - - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unhandled type combination of GemmFloat8"); -#endif -} - -#if !defined(DISABLE_FLOAT8_TYPES) -template -Status GemmFloat8::ComputeFp8Fp16Fp16( - OpKernelContext* ctx, int64_t m, int64_t n, int64_t k, - const Tensor* A, const Tensor* scale_a, const Tensor* B, Tensor* C) const { - ORT_ENFORCE(A->IsDataType() && scale_a->IsDataType() && B->IsDataType()); - - onnxruntime::rocm::tunable::blas::GemmFloat8Params params{}; - params.tuning_ctx = GetTuningContext(); - params.stream = ctx->GetComputeStream(); - params.handle = GetHipblasHandle(ctx); - params.opa = transA_ ? tunable::blas::BlasOp::Trans : tunable::blas::BlasOp::NonTrans; - params.opb = transB_ ? tunable::blas::BlasOp::Trans : tunable::blas::BlasOp::NonTrans; - - params.m = m; - params.n = n; - params.k = k; - - params.a = static_cast(A->DataRaw()); - params.lda = transA_ ? m : k; - params.scale_a = alpha_; - params.scale_a_dev = static_cast(scale_a->DataRaw()); - - params.b = static_cast(B->DataRaw()); - params.ldb = transB_ ? 
k : n; - params.scale_b = 1.0f; // NOTE: not used - params.scale_b_dev = nullptr; // NOTE: not used - - params.c = static_cast(C->MutableDataRaw()); - params.ldc = n; - params.scale_c = 1.0f; // NOTE: not implemented - params.scale_c_dev = nullptr; // NOTE: not implemented - - if (!transA_ && !transB_) { - return (*GetOp())(¶ms); - } else if (transA_ && !transB_) { - ORT_NOT_IMPLEMENTED("transA is not implemented"); - } else if (!transA_ && transB_) { - ORT_NOT_IMPLEMENTED("transB is not implemented"); - } else if (transA_ && transB_) { - ORT_NOT_IMPLEMENTED("transA & transB is not implemented"); - } - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unreachable"); -} - -template -Status GemmFloat8::ComputeFp16Fp8Fp16( - OpKernelContext* ctx, int64_t m, int64_t n, int64_t k, - const Tensor* A, const Tensor* B, const Tensor* scale_b, Tensor* C) const { - ORT_ENFORCE(A->IsDataType() && B->IsDataType() && scale_b->IsDataType()); - - onnxruntime::rocm::tunable::blas::GemmFloat8Params params{}; - params.tuning_ctx = GetTuningContext(); - params.stream = ctx->GetComputeStream(); - params.handle = GetHipblasHandle(ctx); - params.opa = transA_ ? tunable::blas::BlasOp::Trans : tunable::blas::BlasOp::NonTrans; - params.opb = transB_ ? tunable::blas::BlasOp::Trans : tunable::blas::BlasOp::NonTrans; - - params.m = m; - params.n = n; - params.k = k; - - params.a = static_cast(A->DataRaw()); - params.lda = transA_ ? m : k; - params.scale_a = 1.0f; // NOTE: not used - params.scale_a_dev = nullptr; // NOTE: not used - - params.b = static_cast(B->DataRaw()); - params.ldb = transB_ ? k : n; - params.scale_b = alpha_; - params.scale_b_dev = static_cast(scale_b->DataRaw()); - - params.c = static_cast(C->MutableDataRaw()); - params.ldc = n; - params.scale_c = 1.0f; // NOTE: not implemented - params.scale_c_dev = nullptr; // NOTE: not implemented - - if (!transA_ && !transB_) { - return (*GetOp())(¶ms); - } else if (transA_ && !transB_) { - ORT_NOT_IMPLEMENTED("transA is not implemented"); - } else if (!transA_ && transB_) { - return (*GetOp())(¶ms); - } else if (transA_ && transB_) { - ORT_NOT_IMPLEMENTED("transA & transB is not implemented"); - } - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unreachable"); -} -#define GEMM_FLOAT8_CONSTRAINTS BuildKernelDefConstraints() -#else -#define GEMM_FLOAT8_CONSTRAINTS BuildKernelDefConstraints() -#endif - -ONNX_OPERATOR_KERNEL_EX( - GemmFloat8, - kMSDomain, - 1, - kRocmExecutionProvider, - (*KernelDefBuilder::Create()) - .TypeConstraint("TA", GEMM_FLOAT8_CONSTRAINTS) - .TypeConstraint("TB", GEMM_FLOAT8_CONSTRAINTS) - .TypeConstraint("TR", BuildKernelDefConstraints()) - .TypeConstraint("TS", BuildKernelDefConstraints()), - GemmFloat8); - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck.cuh b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck.cuh deleted file mode 100644 index b545eb1f2a149..0000000000000 --- a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck.cuh +++ /dev/null @@ -1,276 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
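A note on the parameter setup above: the lda/ldb/ldc choices follow directly from row-major storage, where a leading dimension is the row stride of the operand as it sits in memory. A small sketch restating that rule (the helper name is ours, not from the deleted file):

#include <cstdint>

// Row-major GEMM C[m,n] = A * B. A transposed operand is stored with its
// logical dimensions swapped, so its row stride swaps accordingly.
struct LeadingDims {
  int64_t lda, ldb, ldc;
};

inline LeadingDims RowMajorLeadingDims(int64_t m, int64_t n, int64_t k,
                                       bool trans_a, bool trans_b) {
  return LeadingDims{trans_a ? m : k,  // A stored [m,k]; [k,m] when transposed
                     trans_b ? k : n,  // B stored [k,n]; [n,k] when transposed
                     n};               // C stored [m,n]
}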
- -#pragma once - -#include -#include -#include - -#if defined(USE_COMPOSABLE_KERNEL) - -#include "core/providers/rocm/composable_kernel_common.h" - -#include "ck/ck.hpp" -#include "ck/utility/functional3.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/device_gemm_splitk.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#endif - -#if !defined(DISABLE_FLOAT8_TYPES) -#include "core/common/float8.h" -#endif -#include "core/providers/rocm/tunable/gemm_common.h" - -namespace onnxruntime { -namespace rocm { -namespace tunable { - -#if defined(USE_COMPOSABLE_KERNEL) && !defined(DISABLE_FLOAT8_TYPES) -using F8 = ck::f8_t; -using F16 = ck::half_t; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -template -constexpr bool always_false = false; - -template -struct Scale { - constexpr const static bool is_pack2_invocable = true; - constexpr const static bool is_pack4_invocable = true; - - explicit Scale(float scale_value, const float* dev_scale_ptr) : scale_value_{scale_value}, dev_scale_ptr_{dev_scale_ptr} {} - - template - __forceinline__ __host__ __device__ Y fast_type_convert(X x) const { - static_assert(always_false, "not implemented"); - (void)x; - } - - template <> - __forceinline__ __host__ __device__ ck::half_t fast_type_convert(ck::f8_t x) const { - // https://github.com/ROCmSoftwarePlatform/triton/blob/0cc3f8b84a16892396f6e08a04991034d67e32b1/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp#L220-L233 - constexpr const uint16_t mask = 0x7fff; - constexpr const uint16_t sign_mask = 0x8000; - constexpr const uint16_t exp_compensate = []() { - if constexpr (std::is_same_v) { - return 0x2000; - } else if constexpr (std::is_same_v) { - return 0x1c00; - } - }(); - - uint8_t x_u8 = reinterpret_cast(x); - uint16_t x_u16 = static_cast(x_u8) << 8; - uint16_t exp = (x_u16 & mask) >> 1; - uint16_t y = (x_u16 & sign_mask) | (exp + exp_compensate); - return reinterpret_cast(y); - } - - __forceinline__ __host__ __device__ void operator()(ck::half_t& y, const ck::f8_t& x) const { - float scale = scale_value_ * (*dev_scale_ptr_); - y = ck::type_convert(scale * fast_type_convert(x)); - } - - __forceinline__ __host__ __device__ void operator()(ck::half2_t& ys, const ck::f8x2_t& xs) const { - float scale = scale_value_ * (*dev_scale_ptr_); - constexpr const uint32_t mask = 0x7fff7fff; - constexpr const uint32_t sign_mask = 0x80008000; - constexpr const uint32_t exp_compensate = []() { - if constexpr (std::is_same_v) { - return 0x20002000; - } else if constexpr (std::is_same_v) { - return 0x1c001c00; - } - }(); - - const uchar2& x2_u8 = reinterpret_cast(xs); - uchar4 x{0, x2_u8.x, 0, x2_u8.y}; - uint32_t x_u32 = reinterpret_cast(x); - - uint32_t exp = (x_u32 & mask) >> 1; - uint32_t v = (x_u32 & sign_mask) | (exp + exp_compensate); - ys = scale * reinterpret_cast(v); - } - - __forceinline__ __host__ __device__ void operator()(ck::half4_t& ys, const ck::f8x4_t& xs) const { - float scale = scale_value_ * (*dev_scale_ptr_); - constexpr const uint32_t mask = 0x7fff7fff; - constexpr const uint32_t sign_mask = 0x80008000; - constexpr const uint32_t exp_compensate = []() { - if constexpr (std::is_same_v) { - return 0x20002000; - } else if constexpr (std::is_same_v) { - return 0x1c001c00; - } - }(); - - uint32_t xs_u32 = reinterpret_cast(xs); - uint32_t x_u32_0 = __byte_perm(xs_u32, 0, 0x1504); - uint32_t x_u32_1 = __byte_perm(xs_u32, 0, 0x3726); - 
uint32_t exp_0 = (x_u32_0 & mask) >> 1; - uint32_t exp_1 = (x_u32_1 & mask) >> 1; - uint32_t v_0 = (x_u32_0 & sign_mask) | (exp_0 + exp_compensate); - uint32_t v_1 = (x_u32_1 & sign_mask) | (exp_1 + exp_compensate); - uint64_t v = v_0 | uint64_t(v_1) << 32; - ys = scale * reinterpret_cast(v); - } - - float scale_value_; - const float* const dev_scale_ptr_; -}; -#endif - -namespace blas { - -template -struct GemmFloat8Params : tunable::OpParams { - std::string Signature() const override { - return MakeString(BlasOpToString(opa), BlasOpToString(opb), "_", m, "_", n, "_", k); - } - - hipblasHandle_t handle; - BlasOp opa; - BlasOp opb; - int64_t m; - int64_t n; - int64_t k; - float scale_a{}; - const float* scale_a_dev{}; - const TA* a; - int64_t lda; - float scale_b{}; - const float* scale_b_dev{}; - const TB* b; - int64_t ldb; - TC* c; - float scale_c{}; - const float* scale_c_dev{}; - int64_t ldc; -}; - -#if defined(USE_COMPOSABLE_KERNEL) && !defined(DISABLE_FLOAT8_TYPES) - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -using Nop = ck::tensor_operation::element_wise::PassThrough; - -void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances( - std::vector, Nop, Nop>>>& instances); - -void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances( - std::vector, Nop, Nop>>>& instances); - -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances( - std::vector, Nop>>>& instances); - -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances( - std::vector, Nop>>>& instances); - -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances( - std::vector, Nop>>>& instances); - -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances( - std::vector, Nop>>>& instances); - -template -auto CreateOp(float scale, const float* dev_scale) { - if constexpr (std::is_same_v) { - return Scale(scale, dev_scale); - } else if constexpr (std::is_same_v) { - return Scale(scale, dev_scale); - } else { - return Nop{}; - } -} - -template -auto GetCKF8SplitKGemmTypeStringAndOps() { - using CKTA = typename CKDataTypeAdaptor::type; - using CKTB = typename CKDataTypeAdaptor::type; - using CKTC = typename CKDataTypeAdaptor::type; - - using CKLayoutA = typename CKBlasOpAdaptor::type; - using CKLayoutB = typename CKBlasOpAdaptor::type; - - using OpA = std::conditional_t, Scale, Nop>; - using OpB = std::conditional_t, Scale, Nop>; - using OpC = std::conditional_t, Scale, Nop>; - - using DeviceGemm = ck::tensor_operation::device::DeviceGemmSplitK< - CKLayoutA, CKLayoutB, Row, - CKTA, CKTB, CKTC, - OpA, OpB, OpC>; - - std::vector>>> ret; - - for (auto num_split : {1, 4, 16, 64}) { - std::vector> instances{}; - if constexpr (std::is_same_v && std::is_same_v && std::is_same_v && - std::is_same_v && std::is_same_v) { - add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances(instances); - } else if constexpr (std::is_same_v && std::is_same_v && std::is_same_v && - std::is_same_v && std::is_same_v) { - add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances(instances); - } else if constexpr (std::is_same_v && std::is_same_v && std::is_same_v && - std::is_same_v && std::is_same_v) { - add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances(instances); - } else { - static_assert(always_false, "no instances for the type combination"); - LOGS_DEFAULT(FATAL) << "no instances for the type combination"; - } - for (auto&& impl : instances) { - auto type_string = std::to_string(ret.size()) + "_" + impl->GetTypeString() + "_SplitK" + std::to_string(num_split); - 
auto invoker = impl->MakeInvokerPointer(); - auto ck_gemm_op = [num_split, impl = std::move(impl), invoker = std::move(invoker)](const GemmFloat8Params* params) -> Status { - OpA op_a = CreateOp(params->scale_a, params->scale_a_dev); - OpB op_b = CreateOp(params->scale_b, params->scale_b_dev); - OpC op_c = CreateOp(params->scale_c, params->scale_c_dev); - - auto arg = impl->MakeArgumentPointer(params->a, params->b, params->c, - params->m, params->n, params->k, - params->lda, params->ldb, params->ldc, - op_a, op_b, op_c, num_split); - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), - impl->GetTypeString(), " does not support ", params->Signature()); - invoker->Run(arg.get(), StreamConfig{params->StreamHandle()}); - return Status::OK(); - }; - ret.emplace_back(std::make_pair(std::move(type_string), std::move(ck_gemm_op))); - } - } - return ret; -} - -#endif // USE_COMPOSABLE_KERNEL - -template -class GemmFloat8TunableOp : public TunableOp> { - public: - GemmFloat8TunableOp() { -#if defined(USE_COMPOSABLE_KERNEL) && !defined(DISABLE_FLOAT8_TYPES) - for (auto&& [_, op] : GetCKF8SplitKGemmTypeStringAndOps()) { - ORT_UNUSED_PARAMETER(_); - this->RegisterOp(std::move(op)); - } -#else - ORT_ENFORCE(false, "CK is required to support GemmFloat8 computing"); -#endif // USE_COMPOSABLE_KERNEL - } -}; - -} // namespace blas -} // namespace tunable -} // namespace rocm -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/add_instance.cu b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/add_instance.cu deleted file mode 100644 index 4c691dd18f2e9..0000000000000 --- a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/add_instance.cu +++ /dev/null @@ -1,124 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
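For context on the TunableOp pattern that GroupNormNHWCTunableOp and GemmFloat8TunableOp above both build on: every registered candidate is tried against a concrete problem, and the fastest supported one is remembered. A rough CPU-side sketch under hypothetical names; the real implementation times on the GPU stream with warmup runs and caches the winner per problem signature:

#include <chrono>
#include <cstddef>
#include <functional>
#include <limits>
#include <vector>

template <typename ParamsT>
struct NaiveTunableOp {
  // A candidate returns false when it cannot handle the given shape.
  using Candidate = std::function<bool(const ParamsT*)>;
  std::vector<Candidate> ops;

  // Run every supported candidate once; return the index of the fastest.
  std::size_t Tune(const ParamsT* params) const {
    std::size_t best = 0;
    double best_ms = std::numeric_limits<double>::infinity();
    for (std::size_t i = 0; i < ops.size(); ++i) {
      const auto t0 = std::chrono::steady_clock::now();
      if (!ops[i](params)) continue;  // unsupported for this problem
      const std::chrono::duration<double, std::milli> ms =
          std::chrono::steady_clock::now() - t0;
      if (ms.count() < best_ms) {
        best_ms = ms.count();
        best = i;
      }
    }
    return best;
  }
};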
- -#include - -#if defined(USE_COMPOSABLE_KERNEL) && !defined(DISABLE_FLOAT8_TYPES) - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp" - -#include "contrib_ops/rocm/math/gemm_float8_ck.cuh" - -namespace onnxruntime { -namespace rocm { -namespace tunable { -namespace blas { - -using F8 = ck::f8_t; -using F16 = ck::half_t; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -template -using S = ck::Sequence; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -namespace internal { -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck( - std::vector, PassThrough>>>& instances); - -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck( - std::vector, PassThrough>>>& instances); - -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort( - std::vector, PassThrough>>>& instances); - -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort( - std::vector, PassThrough>>>& instances); -} // namespace internal - -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances( - std::vector, PassThrough>>>& instances) { - internal::add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck(instances); - internal::add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort(instances); -} - -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances( - std::vector, PassThrough>>>& instances) { - internal::add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck(instances); - internal::add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort(instances); -} - -namespace internal { -void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck( - std::vector, PassThrough, PassThrough>>>& instances); - -void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck( - std::vector, PassThrough, PassThrough>>>& instances); - -// TODO: The first try of derivation does not going well due to various constraints. -// void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ort( -// std::vector, PassThrough, PassThrough>>>& instances); - -// TODO: The first try of derivation does not going well due to various constraints. 
-// void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ort( -// std::vector, PassThrough, PassThrough>>>& instances); -} // namespace internal - -void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances( - std::vector, PassThrough, PassThrough>>>& instances) { - internal::add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck(instances); - // internal::add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ort(instances); // TODO: -} - -void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances( - std::vector, PassThrough, PassThrough>>>& instances) { - internal::add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck(instances); - // internal::add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ort(instances); // TODO: -} - -namespace internal { -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_ck( - std::vector, PassThrough>>>& instances); - -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_ck( - std::vector, PassThrough>>>& instances); -} // namespace internal - -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances( - std::vector, PassThrough>>>& instances) { - internal::add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_ck(instances); -} - -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances( - std::vector, PassThrough>>>& instances) { - internal::add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_ck(instances); -} - -} // namespace blas -} // namespace tunable -} // namespace rocm -} // namespace onnxruntime - -#endif diff --git a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance.cu b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance.cu deleted file mode 100644 index 49463e58886f8..0000000000000 --- a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance.cu +++ /dev/null @@ -1,97 +0,0 @@ -// SPDX-License-Identifier: MIT -// Modifications Copyright (c) Microsoft. -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
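Before this instance file's tables: the Scale elementwise functor consumed by these DeviceGemmXdlSplitKCShuffle instances converts f8 to f16 with a bit-reinterpretation trick (deleted above in gemm_float8_ck.cuh). A scalar restatement, assuming an e4m3 format with exponent bias 7; a bias-8 FNUZ variant would use 0x1c00 in place of 0x2000, matching the two compensation constants in the deleted code. NaN and denormal edge cases are ignored, as in the vectorized original:

#include <cstdint>

inline uint16_t F8E4M3BitsToF16Bits(uint8_t x) {
  const uint16_t x16 = static_cast<uint16_t>(x) << 8;  // move sign|exp|mant to the top
  const uint16_t exp_mant = (x16 & 0x7fff) >> 1;       // re-align into fp16's 5+10 bit fields
  return (x16 & 0x8000) | (exp_mant + 0x2000);         // reattach sign; compensate bias 15 - 7
}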
- -#include - -#if defined(USE_COMPOSABLE_KERNEL) && !defined(DISABLE_FLOAT8_TYPES) - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp" - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -#include "contrib_ops/rocm/math/gemm_float8_ck.cuh" - -namespace onnxruntime { -namespace rocm { -namespace tunable { -namespace blas { -namespace internal { - -template -using S = ck::Sequence; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; -static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - -using ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle; - -template -using device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_generic = std::tuple< - // clang-format off - //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| - //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| - //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNKPadding, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 1, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 8>, 2>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNKPadding, 64, 32, 32, 4, 8, 32, 32, 1, 1, S<1, 2, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 1, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 2> - // clang-format on - >; - -// The derived version is simply double BBlockTransferSrcScalarPerVector and adjust other values correspondingly -template -using device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort = std::tuple< - // clang-format off - //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| 
CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| - //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Type| - //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| | - //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 256, 128, 8, 4, 32, 32, 4, 2, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 128, 256, 8, 4, 32, 32, 2, 4, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 128, 128, 8, 4, 32, 32, 4, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 64, 192, 8, 4, 32, 32, 1, 3, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 192, 64, 8, 4, 32, 32, 3, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 128, 128, 8, 4, 32, 32, 2, 2, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 64, 128, 8, 4, 32, 32, 2, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 128, 64, 8, 4, 32, 32, 2, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 32, 1>, S<0, 
1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 1, 1, S<1, 16, 1, 4>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 64, 128, 8, 4, 32, 32, 1, 2, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 32, 192, 8, 4, 32, 32, 1, 3, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 12, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 16, 4, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 192, 32, 8, 4, 32, 32, 3, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 32, 64, 8, 4, 32, 32, 1, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 64, 32, 8, 4, 32, 32, 1, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 32, 128, 8, 4, 32, 32, 1, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 128, 32, 8, 4, 32, 32, 2, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 1, 1, S<1, 32, 1, 4>, 8, F16> - // clang-format on - >; - -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort( - std::vector, PassThrough>>>& instances) { - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort{}); - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_generic{}); -} - -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort( - std::vector, PassThrough>>>& instances) { - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort{}); - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_generic{}); -} - -} // namespace internal -} // namespace blas -} // namespace tunable -} // namespace rocm -} // namespace onnxruntime - -#endif diff --git a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance_original.cu b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance_original.cu deleted file mode 100644 index 236e5555051fc..0000000000000 --- 
a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance_original.cu +++ /dev/null @@ -1,80 +0,0 @@ -// SPDX-License-Identifier: MIT -// Modifications Copyright (c) Microsoft. -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#if defined(USE_COMPOSABLE_KERNEL) && !defined(DISABLE_FLOAT8_TYPES) - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp" - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -#include "contrib_ops/rocm/math/gemm_float8_ck.cuh" - -namespace onnxruntime { -namespace rocm { -namespace tunable { -namespace blas { -namespace internal { - -template -using S = ck::Sequence; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; - -using ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle; - -// Compilation parameters for a[m, k] * b[k, n] = c[m, n] -template -using device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck = std::tuple< - // clang-format off - //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| - //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Type| - //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| | - //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, 
F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 64, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 48, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 192, 64, 4, 8, 32, 32, 3, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 32, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 192, 32, 4, 8, 32, 32, 3, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 64, 32, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 
8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16> - // clang-format on - >; - -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck( - std::vector, PassThrough>>>& instances) { - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck{}); -} - -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck( - std::vector, PassThrough>>>& instances) { - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck{}); -} - -} // namespace internal -} // namespace blas -} // namespace tunable -} // namespace rocm -} // namespace onnxruntime - -#endif diff --git a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instance.cu b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instance.cu deleted file mode 100644 index 1a0d45df82a71..0000000000000 --- a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instance.cu +++ /dev/null @@ -1,94 +0,0 @@ -// SPDX-License-Identifier: MIT -// Modifications Copyright (c) Microsoft. -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#if defined(USE_COMPOSABLE_KERNEL) && !defined(DISABLE_FLOAT8_TYPES) - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp" - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -#include "contrib_ops/rocm/math/gemm_float8_ck.cuh" - -namespace onnxruntime { -namespace rocm { -namespace tunable { -namespace blas { -namespace internal { - -template -using S = ck::Sequence; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; -static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - -using ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle; - -template -using device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_generic = std::tuple< - // clang-format off - //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| - //#########################| | | | | | | 
| Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| - //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNKPadding, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 1, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 2, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNKPadding, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 1, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 2, F16> - // clang-format on - >; - -// Compilation parameters for a[m, k] * b[k, n] = c[m, n] -template -using device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances = std::tuple< - // clang-format off - //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| - //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| - //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 
S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 4>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 4>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 4>, 8, F16> - // clang-format on - >; - -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_ck( - std::vector, PassThrough>>>& instances) { - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances{}); - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_generic{}); -} - -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_ck( - std::vector, PassThrough>>>& instances) { - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances{}); - 
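The deleted instance files in this region all follow the same Composable Kernel convention: a small `_generic` tuple of conservatively configured instances (MNK padding and, judging from the parameter columns, a source scalar-per-vector of 1, so they stay valid for arbitrary problem shapes) plus a larger tuned tuple (MN padding only, vector widths of 8 or 16), with an `add_device_gemm_*_instances_ck` entry point appending both tuples to one instance list for the tunable-op framework to benchmark. A minimal sketch of that append mechanism, using `OpBase` and two hypothetical instance types as stand-ins for the real CK templates:

```cpp
#include <memory>
#include <tuple>
#include <vector>

// Stand-ins for the CK device-op base class and the concrete
// DeviceGemmXdlSplitKCShuffle<...> instance types from the tuples above.
struct OpBase { virtual ~OpBase() = default; };
struct TunedInstance : OpBase {};    // fast path, assumes vectorizable loads
struct GenericInstance : OpBase {};  // always-applicable fallback

template <typename... Instances>
void add_device_operation_instances(std::vector<std::unique_ptr<OpBase>>& out,
                                    std::tuple<Instances...>) {
  // Default-construct one op object per instance type in the tuple.
  (out.push_back(std::make_unique<Instances>()), ...);
}

void add_all_instances(std::vector<std::unique_ptr<OpBase>>& out) {
  add_device_operation_instances(out, std::tuple<TunedInstance>{});
  add_device_operation_instances(out, std::tuple<GenericInstance>{});
}
```

At tuning time every registered instance that reports itself applicable to the concrete (M, N, K) is timed, which is presumably why the scalar-access generic instances are kept even though the vectorized ones win whenever alignment allows.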
ck::tensor_operation::device::instance::add_device_operation_instances( - instances, device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_generic{}); -} - -} // namespace internal -} // namespace blas -} // namespace tunable -} // namespace rocm -} // namespace onnxruntime - -#endif diff --git a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instance_original.cu b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instance_original.cu deleted file mode 100644 index a0628802ec09e..0000000000000 --- a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instance_original.cu +++ /dev/null @@ -1,97 +0,0 @@ -// SPDX-License-Identifier: MIT -// Modifications Copyright (c) Microsoft. -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#if defined(USE_COMPOSABLE_KERNEL) && !defined(DISABLE_FLOAT8_TYPES) - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp" - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -#include "contrib_ops/rocm/math/gemm_float8_ck.cuh" - -namespace onnxruntime { -namespace rocm { -namespace tunable { -namespace blas { -namespace internal { - -template -using S = ck::Sequence; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; -static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - -using ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle; - -template -using device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_generic = std::tuple< - // clang-format off - //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| - //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| - //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNKPadding, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 1, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 8>, 2>, - 
DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNKPadding, 64, 32, 32, 4, 8, 32, 32, 1, 1, S<1, 2, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 1, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 2> - // clang-format on - >; - -// Compilation parameters for a[m, k] * b[k, n] = c[m, n] -template -using device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck = std::tuple< - // clang-format off - //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| - //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Type| - //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| | - //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 256, 64, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 48, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 256, 192, 64, 4, 8, 32, 32, 3, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 
1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 128, 32, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 128, 192, 32, 4, 8, 32, 32, 3, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 128, 64, 32, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16> - // clang-format on - >; - -void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck( - std::vector, PassThrough, PassThrough>>>& instances) { - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck{}); - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_generic{}); -} - -void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck( - 
std::vector, PassThrough, PassThrough>>>& instances) { - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck{}); - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_generic{}); -} - -} // namespace internal -} // namespace blas -} // namespace tunable -} // namespace rocm -} // namespace onnxruntime - -#endif diff --git a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc deleted file mode 100644 index 7dbb24463961e..0000000000000 --- a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc +++ /dev/null @@ -1,347 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/providers/shared_library/provider_api.h" -#include "core/providers/rocm/rocm_common.h" - -using namespace onnxruntime::common; - -namespace onnxruntime { -namespace contrib { -namespace rocm { -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, GridSample); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, FastGelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, FastGelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, FastGelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FastGelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, Gelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, Gelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, Gelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, Gelu); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BiasGelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, BiasSplitGelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, BiasSplitGelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, BiasAdd); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, BiasAdd); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, QuickGelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, QuickGelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, QuickGelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, TransposeMatMul); // backward compatibility -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, TransposeMatMul); // backward compatibility -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, TransposeMatMul); // backward compatibility -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, FusedMatMul); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, FusedMatMul); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, FusedMatMul); -// class 
ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, RelativePositionBias); -// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, RelativePositionBias); -// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, GatedRelativePositionBias); -// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, GatedRelativePositionBias); -// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, RemovePadding); -// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, RemovePadding); -// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, RestorePadding); -// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, RestorePadding); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, Rfft); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, Rfft); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, Rfft); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, Irfft); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, Irfft); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, Irfft); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, ComplexMul); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, ComplexMul); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, ComplexMulConj); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, ComplexMulConj); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BiasSoftmax); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BiasDropout); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BitmaskDropout); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BitmaskBiasDropout); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, NGramRepeatBlock); - -// These ops were experimental ops in onnx domain which have been removed now. 
We add them here as -// contrib ops to maintain backward compatibility -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, Affine); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double, Affine); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, Affine); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, Attention); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, Attention); -// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, PackedAttention); -// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, PackedAttention); -// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, PackedMultiHeadAttention); -// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, PackedMultiHeadAttention); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BeamSearch); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, ConvTransposeWithDynamicPads); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, Crop); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double, Crop); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, Crop); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, MultiHeadAttention); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, MultiHeadAttention); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, GroupQueryAttention); -// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, GroupQueryAttention); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, DecoderAttention); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, DecoderAttention); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, int32_t, DynamicSlice); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, int64_t, DynamicSlice); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, EmbedLayerNormalization); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, EmbedLayerNormalization); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, GreedySearch); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, GroupNorm); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, NhwcConv); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, NhwcConv); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, ImageScaler); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double, ImageScaler); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, ImageScaler); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, LongformerAttention); -class 
ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, LongformerAttention); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, ParametricSoftplus); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double, ParametricSoftplus); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, ParametricSoftplus); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, Sampling); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, ScaledTanh); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double, ScaledTanh); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, ScaledTanh); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, SkipGroupNorm); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, SkipLayerNormalization); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, SkipLayerNormalization); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, SkipSimplifiedLayerNormalization); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, SkipSimplifiedLayerNormalization); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, ThresholdedRelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double, ThresholdedRelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, ThresholdedRelu); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 16, float_float_float, LayerNormalization); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 16, double_double_double, LayerNormalization); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 16, MLFloat16_float_MLFloat16, LayerNormalization); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 16, float_float_MLFloat16, LayerNormalization); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 16, MLFloat16_float_float, LayerNormalization); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 16, BFloat16_float_BFloat16, LayerNormalization); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float_float_float, SimplifiedLayerNormalization); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double_double_double, SimplifiedLayerNormalization); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16_float_MLFloat16, SimplifiedLayerNormalization); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float_float_MLFloat16, SimplifiedLayerNormalization); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16_float_float, SimplifiedLayerNormalization); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, Inverse); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, MatMulNBits); -class 
ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, MatMulNBits); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, Trilu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, int8_t_MLFloat16, QuantizeLinear); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, uint8_t_MLFloat16, QuantizeLinear); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, int8_t_MLFloat16, DequantizeLinear); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, uint8_t_MLFloat16, DequantizeLinear); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float_int8_t, QAttention); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_int8_t, QAttention); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, FusedConv); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, FusedConv); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, TransposeMatMul); // backward compatibility -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FusedMatMul); -// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, QOrderedMatMul); -// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, QOrderedLayerNormalization); -// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, QOrderedGelu); -// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, QuantizeWithOrder); -// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, DequantizeWithOrder); -// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, QOrderedAttention); -// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, QOrderedLongformerAttention); -// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, DecoderMaskedSelfAttention); -// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, DecoderMaskedSelfAttention); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, DecoderMaskedMultiHeadAttention); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, DecoderMaskedMultiHeadAttention); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, GemmFastGelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, GemmFastGelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, GemmFastGelu); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, GemmFloat8); - -#ifdef ENABLE_ATEN -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kPytorchAtenDomain, 1, ATen); -#endif - -#ifdef ENABLE_TRAINING_OPS -// Should remove the shrunken_gather include from ENABLE_TRAINING_OPS once 1). compute optimizer is enabled for inference or -// 2). this is needed by inference for other purpose. 
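The long forward-declaration list in this deleted file is ORT's standard registration idiom: each `ONNX_OPERATOR_*_KERNEL_CLASS_NAME` use declares a class whose name is mangled from the (provider, domain, version, type, op) tuple, and the `BuildKernelCreateInfo` table further down references those classes. A simplified sketch of the name-mangling idea (not the exact ORT macro definition):

```cpp
// Simplified sketch: paste the registration tuple into a unique class name.
#define TYPED_KERNEL_CLASS_NAME(provider, domain, ver, type, name) \
  provider##_##name##_##domain##_ver##ver##_##type

// Forward declaration only; the class itself is defined next to the kernel.
class TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, FastGelu);
// expands to:
//   class kRocmExecutionProvider_FastGelu_kMSDomain_ver1_float;
```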
-class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, ShrunkenGather); -#endif - -#ifdef ORT_USE_NCCL -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, AllReduce); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, AllGather); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, AllToAll); -#endif - -template <> -KernelCreateInfo BuildKernelCreateInfo() { - KernelCreateInfo info; - return info; -} - -// clang-format off -Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) { - static const BuildKernelCreateInfoFn function_table[] = { - BuildKernelCreateInfo, // default entry to avoid the list become empty after ops-reducing - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, // backward compatibility - BuildKernelCreateInfo, // backward compatibility - BuildKernelCreateInfo, // backward compatibility - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - // These ops were experimental ops in onnx domain which have been removed now. 
We add them here as - // contrib ops to maintain backward compatibility - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - // TransposedMatMul is still here for backward compatibility - BuildKernelCreateInfo, // backward compatibility - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - -#ifdef ENABLE_ATEN - BuildKernelCreateInfo, -#endif - -#ifdef ENABLE_TRAINING_OPS - // Should remove the shrunken_gather include from ENABLE_TRAINING_OPS once 1). compute optimizer is enabled for inference or - // 2). this is needed by inference for other purpose. 
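One detail of this table worth noting: its first entry is the explicit `BuildKernelCreateInfo` specialization declared just above it (the `void` specialization in the original source), which returns a default `KernelCreateInfo` with a null `kernel_def`. That keeps the array non-empty when op-reduction builds compile everything else out, and the registration loop that follows simply skips null entries. A condensed sketch of the mechanism, with simplified types:

```cpp
#include <vector>

struct KernelDef {};
struct KernelCreateInfo {
  KernelDef* kernel_def = nullptr;  // null means "disabled entry"
};
using BuildFn = KernelCreateInfo (*)();

template <typename T>
KernelCreateInfo BuildKernelCreateInfo();  // specialized per kernel class

template <>
KernelCreateInfo BuildKernelCreateInfo<void>() {
  return {};  // kernel_def stays null, so the loop below filters it out
}

void register_all(const std::vector<BuildFn>& table,
                  std::vector<KernelCreateInfo>& registry) {
  for (BuildFn build : table) {
    KernelCreateInfo info = build();
    if (info.kernel_def != nullptr) {  // mirrors the check in the real loop
      registry.push_back(info);
    }
  }
}
```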
- BuildKernelCreateInfo, -#endif - -#ifdef ORT_USE_NCCL - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, -#endif - - }; - - for (auto& function_table_entry : function_table) { - KernelCreateInfo info = function_table_entry(); - if (info.kernel_def != nullptr) { // filter disabled entries where type is void - ORT_RETURN_IF_ERROR(kernel_registry.Register(std::move(info))); - } - } - - return Status::OK(); -} -// clang-format on - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.h b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.h deleted file mode 100644 index db9a5d4fcd83e..0000000000000 --- a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.h +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -Status RegisterRocmContribKernels(KernelRegistry& kernel_registry); - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc index a5ab63d74df24..130dd0c25a880 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc @@ -165,7 +165,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const { << " let query_pos = m + local_id.y + past_sequence_length;\n" << " let key_pos = n + local_id.x;\n" << " if (key_pos > query_pos) {\n" - << " sum = -3.40282e+38; // Set to very negative value for masking\n" + << " sum = -3.4028234663852886e+38; // Set to very negative value for masking\n" << " }\n"; } @@ -272,7 +272,7 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const { << "let effective_seq_length = seq_causal_length;\n"; } shader.MainFunctionBody() - << "var thread_max_vector = f32_val_t(-3.402823e+38f);\n" + << "var thread_max_vector = f32_val_t(-3.4028234663852886e+38f);\n" << "for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < effective_seq_length; i++) {\n" << " let actual_pos = local_offset + i + start_offset;\n" << " if (!should_apply_local_window || actual_pos < seq_causal_length) {\n" @@ -289,7 +289,7 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const { } else if (use_smooth_softmax_) { shader.MainFunctionBody() << "var max_value: f32 = 0.0;\n"; } else { - shader.MainFunctionBody() << "var max_value = f32(-3.402823e+38f);\n"; + shader.MainFunctionBody() << "var max_value = f32(-3.4028234663852886e+38f);\n"; } shader.MainFunctionBody() << "for (var i = 0u; i < " << work_group_size_ << "; i++) {\n" diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc index 606dbfde15c2c..2a67dfdb07912 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc @@ -421,7 +421,7 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co indirect_buffer_ptr, tile_size)); Q = &query_output; } else { - ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, use_indirect_dispatch ? 
seqlen_k : nullptr, indirect_buffer_ptr)); + ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, use_seqlen_k ? seqlen_k : nullptr, indirect_buffer_ptr)); } if (parameters.sequence_length_ > 1) { @@ -571,8 +571,8 @@ Status RunSplitPackedQKVWithRotaryEmbeddingAndCopyKV(onnxruntime::webgpu::Comput {static_cast(params.kv_hidden_size_ / components)}, {static_cast(params.num_heads_)}, {static_cast(params.kv_num_heads_)}, - {head_size_vec}, - {half_rotary_embedding_dim_vec}, + {static_cast(head_size_vec)}, + {static_cast(half_rotary_embedding_dim_vec)}, {present_sequence_length}, {tile_size}, {static_cast(dispatch_size)}, diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.wgsl.template index a5922ec9512fd..ff8e4ecc08bab 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.wgsl.template @@ -26,7 +26,7 @@ fn get_total_sequence_length() -> u32 { #if is_fp16 const min_value = q_element_t(-65504.0); #else -const min_value = q_element_t(-3.402823e+38f); +const min_value = q_element_t(-3.4028234663852886e+38f); #endif // For max performance max_k_step should be the same as sg_size, however we might run out of registers diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkt.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkt.wgsl.template index c6f768beffa0f..ac9a157492007 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkt.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkt.wgsl.template @@ -93,7 +93,7 @@ $MAIN { if (local_idx == 0u) { // Calculate the max and sum in current split. - var l_max = f32(-3.402823e+38f); + var l_max = f32(-3.4028234663852886e+38f); var l_sum = f32(0); for (var i = 0u; i < tile_size && (total_seq_offset + i) < total_sequence_length; i++) { l_max = max(l_max, f32(tile_qk[i])); diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_split_vx.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_split_vx.wgsl.template index 37cf7e8f11b1f..a113e96130985 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_split_vx.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_split_vx.wgsl.template @@ -54,7 +54,7 @@ $MAIN { // Calculate the global max and sum in qk. 
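A note on the constant change that recurs through this part of the diff (the attention and flash-attention shaders above, plus the MoE gate, JS Clip, and VSINPU Clip hunks below): `-3.402823e+38` is a truncated rendering of the float32 extreme, and once parsed it is not the same `f32` value as the true largest-magnitude finite float; it lands about two ULPs short. `3.4028234663852886e+38` is FLT_MAX written to full precision and converts exactly, so max-reduction seeds and clip defaults now start from the genuine extreme. A quick C++ check:

```cpp
#include <cfloat>
#include <cstdio>

int main() {
  float exact = 3.4028234663852886e+38f;  // full-precision FLT_MAX literal
  float truncated = 3.402823e+38f;        // parses to a slightly smaller float
  std::printf("exact == FLT_MAX:     %d\n", exact == FLT_MAX);      // prints 1
  std::printf("truncated == FLT_MAX: %d\n", truncated == FLT_MAX);  // prints 0
  return 0;
}
```

The same reasoning carries over to the WGSL literals, since WGSL's f32 follows IEEE-754 binary32 just like C++ `float`.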
if (head_idx < uniforms.num_heads) { - var g_max = f32(-3.402823e+38f); + var g_max = f32(-3.4028234663852886e+38f); var g_sum = f32(0); for (var i = 0u; i < num_total_seq_length_tile; i++) { diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc index 05717fd2fe686..416a895e61745 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc @@ -128,8 +128,8 @@ Status RunSplitPackedQKVWithRotaryEmbedding(onnxruntime::webgpu::ComputeContext& {static_cast(params.kv_hidden_size_ / components)}, {static_cast(params.num_heads_)}, {static_cast(params.kv_num_heads_)}, - {head_size_vec}, - {half_rotary_embedding_dim_vec}, + {static_cast(head_size_vec)}, + {static_cast(half_rotary_embedding_dim_vec)}, {static_cast(dispatch_size)}, }) .SetDispatchGroupSize((dispatch_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE); diff --git a/onnxruntime/contrib_ops/webgpu/moe/gate.wgsl.template b/onnxruntime/contrib_ops/webgpu/moe/gate.wgsl.template index 1214777009a8d..6e0d4c7299793 100644 --- a/onnxruntime/contrib_ops/webgpu/moe/gate.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/moe/gate.wgsl.template @@ -18,7 +18,7 @@ const K: u32 = k; #if is_fp16 const MAX_FLOAT: f16 = 65504.0; #else -const MAX_FLOAT: f32 = 3.402823466e+38; +const MAX_FLOAT: f32 = 3.4028234663852886e+38; #endif var shared_vals: array; diff --git a/onnxruntime/core/framework/allocation_planner.cc b/onnxruntime/core/framework/allocation_planner.cc index e77496b6e8196..1c80d83f99feb 100644 --- a/onnxruntime/core/framework/allocation_planner.cc +++ b/onnxruntime/core/framework/allocation_planner.cc @@ -499,8 +499,7 @@ class PlannerImpl { /*! \brief Given a tensor-type, return the size of an element of the tensor. */ static size_t GetElementSize(const DataType& tensor_type) { - const TypeProto& type_proto = ONNX_NAMESPACE::Utils::DataTypeUtils::ToTypeProto(tensor_type); - MLDataType ml_data_type = DataTypeImpl::TypeFromProto(type_proto); + MLDataType ml_data_type = DataTypeImpl::GetDataType(*tensor_type); const TensorTypeBase* tensor_type_base = ml_data_type->AsTensorType(); ORT_ENFORCE(nullptr != tensor_type_base); MLDataType elt_type = tensor_type_base->GetElementType(); diff --git a/onnxruntime/core/framework/ort_value_name_idx_map.h b/onnxruntime/core/framework/ort_value_name_idx_map.h index 76e7e369514d4..6035dc4e85242 100644 --- a/onnxruntime/core/framework/ort_value_name_idx_map.h +++ b/onnxruntime/core/framework/ort_value_name_idx_map.h @@ -33,7 +33,7 @@ class OrtValueNameIdxMap { common::Status GetIdx(std::string_view name, int& idx) const { idx = -1; - auto it = map_.find(std::string(name)); + auto it = map_.find(name); if (it == map_.end()) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Could not find OrtValue with name '", name, "'"); } diff --git a/onnxruntime/core/graph/contrib_ops/nhwc_inference_context.h b/onnxruntime/core/graph/contrib_ops/nhwc_inference_context.h index bc52a45adfd43..94ef87fb069af 100644 --- a/onnxruntime/core/graph/contrib_ops/nhwc_inference_context.h +++ b/onnxruntime/core/graph/contrib_ops/nhwc_inference_context.h @@ -83,7 +83,8 @@ class NhwcInferenceContext : public ONNX_NAMESPACE::InferenceContext { const int rank = nchw_shape.dim_size(); // N and C dims are required. 
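The `OrtValueNameIdxMap` hunk above removes a `std::string` temporary on every lookup: `map_.find(name)` can accept the `std::string_view` directly only when the map supports heterogeneous lookup, i.e. its hash and equality functors are transparent. A sketch of one way to declare such a map, assuming a C++20 standard library (the actual `map_` declaration in ORT may differ):

```cpp
#include <functional>
#include <string>
#include <string_view>
#include <unordered_map>

// Transparent hash: lets find() take a string_view without building a string.
struct StringHash {
  using is_transparent = void;
  size_t operator()(std::string_view s) const noexcept {
    return std::hash<std::string_view>{}(s);
  }
};

using NameIdxMap =
    std::unordered_map<std::string, int, StringHash, std::equal_to<>>;

int GetIdxOrMinusOne(const NameIdxMap& map, std::string_view name) {
  auto it = map.find(name);  // heterogeneous find, no allocation (C++20)
  return it == map.end() ? -1 : it->second;
}
```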
Some operators like AveragePool allow 1D input if (rank < 3) { - fail_shape_inference("Output tensor must have at least 3 dimensions"); + *nhwc_tp.mutable_tensor_type()->mutable_shape() = nchw_shape; + return; } // Convert output shape from N, C, H {, W, ...} to N, H {, W, ...}, C. @@ -105,8 +106,8 @@ class NhwcInferenceContext : public ONNX_NAMESPACE::InferenceContext { const int rank = nhwc_shape.dim_size(); // N and C dims are required. Some operators like AveragePool allow 1D input. if (rank < 3) { - fail_shape_inference( - "Tensor must have at least 3 dimensions to convert between channels first and channels last."); + *nchw_tp.mutable_tensor_type()->mutable_shape() = nhwc_shape; + return; } // Convert input shape from {N, H, W, ..., C} to {N, C, H, W, ...}. diff --git a/onnxruntime/core/platform/telemetry.cc b/onnxruntime/core/platform/telemetry.cc index 6cbbdd4e0a7ef..1eb03af3befa4 100644 --- a/onnxruntime/core/platform/telemetry.cc +++ b/onnxruntime/core/platform/telemetry.cc @@ -81,6 +81,10 @@ void Telemetry::LogSessionCreation(uint32_t session_id, int64_t ir_version, cons ORT_UNUSED_PARAMETER(captureState); } +void Telemetry::LogCompileModel(uint32_t session_id) const { + ORT_UNUSED_PARAMETER(session_id); +} + void Telemetry::LogRuntimeError(uint32_t session_id, const common::Status& status, const char* file, const char* function, uint32_t line) const { ORT_UNUSED_PARAMETER(session_id); diff --git a/onnxruntime/core/platform/telemetry.h b/onnxruntime/core/platform/telemetry.h index b60345e1b8a80..9c2859f7634b6 100644 --- a/onnxruntime/core/platform/telemetry.h +++ b/onnxruntime/core/platform/telemetry.h @@ -66,6 +66,8 @@ class Telemetry { const std::string& loadedFrom, const std::vector& execution_provider_ids, bool use_fp16, bool captureState) const; + virtual void LogCompileModel(uint32_t session_id) const; + virtual void LogRuntimeError(uint32_t session_id, const common::Status& status, const char* file, const char* function, uint32_t line) const; diff --git a/onnxruntime/core/platform/windows/telemetry.cc b/onnxruntime/core/platform/windows/telemetry.cc index 2e5d334856278..693e265af46b1 100644 --- a/onnxruntime/core/platform/windows/telemetry.cc +++ b/onnxruntime/core/platform/windows/telemetry.cc @@ -334,6 +334,20 @@ void WindowsTelemetry::LogSessionCreation(uint32_t session_id, int64_t ir_versio } } +void WindowsTelemetry::LogCompileModel(uint32_t session_id) const { + if (global_register_count_ == 0 || enabled_ == false) + return; + + TraceLoggingWrite(telemetry_provider_handle, + "CompileModel", + TraceLoggingBool(true, "UTCReplace_AppSessionGuid"), + TelemetryPrivacyDataTag(PDT_ProductAndServiceUsage), + TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES), + // Telemetry info + TraceLoggingUInt8(0, "schemaVersion"), + TraceLoggingUInt32(session_id, "sessionId")); +} + void WindowsTelemetry::LogRuntimeError(uint32_t session_id, const common::Status& status, const char* file, const char* function, uint32_t line) const { if (global_register_count_ == 0 || enabled_ == false) diff --git a/onnxruntime/core/platform/windows/telemetry.h b/onnxruntime/core/platform/windows/telemetry.h index 261d14a7fed8c..044feec071223 100644 --- a/onnxruntime/core/platform/windows/telemetry.h +++ b/onnxruntime/core/platform/windows/telemetry.h @@ -59,6 +59,8 @@ class WindowsTelemetry : public Telemetry { const std::string& loadedFrom, const std::vector& execution_provider_ids, bool use_fp16, bool captureState) const override; + void LogCompileModel(uint32_t session_id) const override; + void 
LogRuntimeError(uint32_t session_id, const common::Status& status, const char* file, const char* function, uint32_t line) const override; diff --git a/onnxruntime/core/providers/js/operators/unary.cc b/onnxruntime/core/providers/js/operators/unary.cc index ef977161bcc37..26144e6ba3995 100644 --- a/onnxruntime/core/providers/js/operators/unary.cc +++ b/onnxruntime/core/providers/js/operators/unary.cc @@ -126,7 +126,7 @@ JSEP_ELEMENTWISE_TYPED_KERNEL(Not, 1, bool, Not) // activation -JSEP_CLASS_IMPL_ATTRIBUTE_FLOAT_2_DEFAULT(ClipV10, Clip, min, 3.402823e+38f, max, -3.402823e+38f) +JSEP_CLASS_IMPL_ATTRIBUTE_FLOAT_2_DEFAULT(ClipV10, Clip, min, 3.4028234663852886e+38f, max, -3.4028234663852886e+38f) JSEP_ELEMENTWISE_VERSIONED_KERNEL(Clip, 6, 10, ClipV10) JSEP_KERNEL_IMPL(Clip, Clip) ONNX_OPERATOR_VERSIONED_KERNEL_EX(Clip, kOnnxDomain, 11, 11, kJsExecutionProvider, diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc index e2a8005aba1da..d148c4191d5d7 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc @@ -1407,9 +1407,30 @@ std::unique_ptr NvExecutionProvider::GetSubGraph(SubGraph_t gra } // Find inputs and outputs of the subgraph + std::unique_ptr<IndexedSubGraph> sub_graph = onnxruntime::IndexedSubGraph::Create(); - std::unordered_map<const NodeArg*, int> original_inputs, fused_inputs, fused_outputs, fused_outputs_to_add, graph_outputs_to_add; + std::unordered_map<const NodeArg*, int> original_inputs; + + // These maps store the inputs and outputs of the subgraph. + // Note that their contents are updated dynamically during node iteration + // to determine the final inputs and outputs of the subgraph. + std::unordered_map<const NodeArg*, int> fused_inputs, fused_outputs; + + // This map stores a node output that will be consumed by another node outside of this subgraph, + // so that output must be put into the subgraph's output list. + std::unordered_map<const NodeArg*, int> fused_outputs_to_add; + + // This map stores a node output that is also the original graph's output, + // so that output must be put into the subgraph's output list. + std::unordered_map<const NodeArg*, int> graph_outputs_to_add; + std::unordered_set<const NodeArg*> erased; + + // These counters assign a relative order index to each input or output as it is added to the + // 'fused_inputs', 'fused_outputs', 'fused_outputs_to_add' and 'graph_outputs_to_add' maps. + // Items added earlier receive a smaller order index than items added later. + // When constructing the final sub_graph's input or output lists, entries with smaller + // order indices will appear before those with larger indices.
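The counters these comments describe work together with a subtle container change in the hunks that follow: `fused_inputs[input] = input_order++` becomes `fused_inputs.insert({input, input_order++})`. The two are not equivalent. `operator[]` overwrites the stored order index every time the same `NodeArg` is revisited, while `insert` is a no-op for an existing key, so each entry keeps its first-seen position (the counter still advances either way, which can leave gaps in the sequence but never reorders it). A self-contained illustration:

```cpp
#include <cstdio>
#include <unordered_map>

int main() {
  std::unordered_map<const char*, int> via_subscript, via_insert;
  int order_a = 0, order_b = 0;
  const char* x = "x";  // stands in for a NodeArg* reached via two nodes

  via_subscript[x] = order_a++;  // x -> 0
  via_subscript[x] = order_a++;  // x -> 1: order overwritten on revisit

  via_insert.insert({x, order_b++});  // x -> 0
  via_insert.insert({x, order_b++});  // no-op: x keeps 0, counter still bumps

  std::printf("%d %d\n", via_subscript[x], via_insert[x]);  // prints "1 0"
  return 0;
}
```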
int input_order = 0; int output_order = 0; @@ -1428,7 +1449,7 @@ std::unique_ptr NvExecutionProvider::GetSubGraph(SubGraph_t gra erased.insert(input); } else if (erased.find(input) == erased.end()) { // Only when input is neither in output list nor erased list, add the input to input list - fused_inputs[input] = input_order++; + fused_inputs.insert({input, input_order++}); } } @@ -1443,7 +1464,7 @@ std::unique_ptr NvExecutionProvider::GetSubGraph(SubGraph_t gra erased.insert(input); } else if (erased.find(input) == erased.end()) { // Only when input is neither in output list nor erased list, add the input to input list - fused_inputs[input] = input_order++; + fused_inputs.insert({input, input_order++}); } } @@ -1464,39 +1485,33 @@ std::unique_ptr NvExecutionProvider::GetSubGraph(SubGraph_t gra } else { output = (it->GetNode()).ImplicitInputDefs()[it->GetDstArgIndex() - static_cast(it->GetNode().InputDefs().size())]; } - if (node_set.find(node_idx) != node_set.end()) { - const auto& iter = fused_inputs.find(output); - if (iter != fused_inputs.end()) { - fused_inputs.erase(iter); - erased.insert(output); - } else if (erased.find(output) == erased.end()) { - if (graph_output_names.find(output->Name()) != graph_output_names.end()) { - graph_outputs_to_add[output] = output_order; - } - fused_outputs[output] = output_order++; - } - } else { - fused_outputs_to_add[output] = output_order++; + + if (node_set.find(node_idx) == node_set.end()) { + // This output will be consumed by another node outside of this subgraph. + // So the output should be put into the subgraph's output list. + fused_outputs_to_add.insert({output, output_order++}); } } - } else { - for (const auto& output : node->OutputDefs()) { - const auto& it = fused_inputs.find(output); - if (it != fused_inputs.end()) { - fused_inputs.erase(it); - erased.insert(output); - } - // Only when output is neither in input list nor erased list, and the output is consumed by another node, add the output to output list - else if (erased.find(output) == erased.end()) { - if (graph_output_names.find(output->Name()) != graph_output_names.end()) { - graph_outputs_to_add[output] = output_order; - } + } - if (graph.GetGraph().GetConsumerNodes(output->Name()).size() > 0) { - fused_outputs[output] = output_order++; - } + for (const auto& output : node->OutputDefs()) { + const auto& it = fused_inputs.find(output); + if (it != fused_inputs.end()) { + fused_inputs.erase(it); + erased.insert(output); + } else if (erased.find(output) == erased.end()) { + if (graph.GetGraph().GetConsumerNodes(output->Name()).size() > 0) { + // Only when output is neither in input list nor erased list, + // and the output is consumed by another node, add the output to output list + fused_outputs.insert({output, output_order++}); } } + + if (graph_output_names.find(output->Name()) != graph_output_names.end()) { + // This output is the graph's output. + // So the output should be put into the subgraph's output list. 
+ graph_outputs_to_add.insert({output, output_order++}); + } } } diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc index 4d183b95bd938..0bb3accb4d754 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc @@ -76,6 +76,9 @@ Status BaseOpBuilder::ProcessDataTypes(QnnModelWrapper& qnn_model_wrapper, return CheckHtpDataTypes(input_qnn_dtypes, output_qnn_dtypes); } else if (IsGpuBackend(qnn_model_wrapper.GetQnnBackendType())) { return CheckGpuDataTypes(input_qnn_dtypes, output_qnn_dtypes); + } else if (IsIrBackend(qnn_model_wrapper.GetQnnBackendType())) { + // TODO: CheckIrDataTypes + return Status::OK(); } return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Only support backend: CPU, HTP and GPU"); } diff --git a/onnxruntime/core/providers/qnn/builder/qnn_def.cc b/onnxruntime/core/providers/qnn/builder/qnn_def.cc index f3d81d7d2fdd7..9f28e2609faa1 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_def.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_def.cc @@ -574,6 +574,10 @@ bool QnnOpConfigWrapper::CreateQnnGraphOp(const QNN_INTERFACE_VER_TYPE& qnn_inte return true; } +bool IsIrBackend(QnnBackendType backend_type) { + return backend_type == QnnBackendType::SERIALIZER; +} + bool IsNpuBackend(QnnBackendType backend_type) { return backend_type == QnnBackendType::HTP || backend_type == QnnBackendType::DSP; } diff --git a/onnxruntime/core/providers/qnn/builder/qnn_def.h b/onnxruntime/core/providers/qnn/builder/qnn_def.h index 42f4d7bb60f34..77508f3934a20 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_def.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_def.h @@ -96,6 +96,8 @@ enum class QnnBackendType : uint8_t { SERIALIZER, }; +bool IsIrBackend(QnnBackendType backend_type); + bool IsCpuBackend(QnnBackendType backend_type); bool IsNpuBackend(QnnBackendType backend_type); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.cc b/onnxruntime/core/providers/qnn/builder/qnn_model.cc index 85901ab6fdfec..8973a4efa8ba1 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_model.cc @@ -222,14 +222,14 @@ Status QnnModel::SetupQnnInputOutput(const logging::Logger& logger) { auto result = SetupTensors(qnn_input_infos_, graph_info_->InputTensors()); if (Status::OK() != result) { - const std::string message = "Failed to setup QNN input tensors for graph: " + graph_info_->Name(); + const std::string message = "Failed to setup QNN input tensors for graph: " + graph_info_->Name() + ". " + result.ErrorMessage(); LOGS(logger, ERROR) << message; return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, message); } result = SetupTensors(qnn_output_infos_, graph_info_->OutputTensors(), false); if (Status::OK() != result) { - const std::string message = "Failed to setup QNN output tensors for graph: " + graph_info_->Name(); + const std::string message = "Failed to setup QNN output tensors for graph: " + graph_info_->Name() + ". 
" + result.ErrorMessage(); LOGS(logger, ERROR) << message; return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, message); } diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index cd0c0e4bffdb5..e5b48da33fbc3 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -2035,9 +2035,30 @@ std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGraph } // Find inputs and outputs of the subgraph + std::unique_ptr sub_graph = onnxruntime::IndexedSubGraph::Create(); - std::unordered_map original_inputs, fused_inputs, fused_outputs, fused_outputs_to_add, graph_outputs_to_add; + std::unordered_map original_inputs; + + // These maps store the inputs and outputs of the subgraph. + // Please note that the inputs and outputs of the maps will be dynamically updated during node iteration + // to determine the final inputs and outputs of the subgraph. + std::unordered_map fused_inputs, fused_outputs; + + // This map stores the node's output that will be consumed by another node outside of this subgraph. + // So the node's output should be put into the subgraph's output list. + std::unordered_map fused_outputs_to_add; + + // This map stores the node's output that is original graph's output. + // So the node's output should be put into the subgraph's output list. + std::unordered_map graph_outputs_to_add; + std::unordered_set erased; + + // This is the relative ordering that ensures node's input or output being added to the 'fused_inputs', + // 'fused_outputs', 'fused_outputs_to_add' and 'graph_outputs_to_add' maps is associated with a relative order index. + // Items added earlier receive a smaller order index than items added later. + // When constructing the final sub_graph's input or output lists, entries with smaller + // order indices will appear before those with larger indices. int input_order = 0; int output_order = 0; @@ -2056,7 +2077,7 @@ std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGraph erased.insert(input); } else if (erased.find(input) == erased.end()) { // Only when input is neither in output list nor erased list, add the input to input list - fused_inputs[input] = input_order++; + fused_inputs.insert({input, input_order++}); } } @@ -2071,7 +2092,7 @@ std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGraph erased.insert(input); } else if (erased.find(input) == erased.end()) { // Only when input is neither in output list nor erased list, add the input to input list - fused_inputs[input] = input_order++; + fused_inputs.insert({input, input_order++}); } } @@ -2092,39 +2113,33 @@ std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGraph } else { output = (it->GetNode()).ImplicitInputDefs()[it->GetDstArgIndex() - static_cast(it->GetNode().InputDefs().size())]; } - if (node_set.find(node_idx) != node_set.end()) { - const auto& iter = fused_inputs.find(output); - if (iter != fused_inputs.end()) { - fused_inputs.erase(iter); - erased.insert(output); - } else if (erased.find(output) == erased.end()) { - if (graph_output_names.find(output->Name()) != graph_output_names.end()) { - graph_outputs_to_add[output] = output_order; - } - fused_outputs[output] = output_order++; - } - } else { - fused_outputs_to_add[output] = output_order++; + + if (node_set.find(node_idx) == node_set.end()) { + // This output will be consumed by another node outside of this subgraph. 
+ // So the output should be put into the subgraph's output list. + fused_outputs_to_add.insert({output, output_order++}); } } - } else { - for (const auto& output : node->OutputDefs()) { - const auto& it = fused_inputs.find(output); - if (it != fused_inputs.end()) { - fused_inputs.erase(it); - erased.insert(output); - } - // Only when output is neither in input list nor erased list, and the output is consumed by another node, add the output to output list - else if (erased.find(output) == erased.end()) { - if (graph_output_names.find(output->Name()) != graph_output_names.end()) { - graph_outputs_to_add[output] = output_order; - } + } - if (graph.GetGraph().GetConsumerNodes(output->Name()).size() > 0) { - fused_outputs[output] = output_order++; - } + for (const auto& output : node->OutputDefs()) { + const auto& it = fused_inputs.find(output); + if (it != fused_inputs.end()) { + fused_inputs.erase(it); + erased.insert(output); + } else if (erased.find(output) == erased.end()) { + if (graph.GetGraph().GetConsumerNodes(output->Name()).size() > 0) { + // Only when output is neither in input list nor erased list, + // and the output is consumed by another node, add the output to output list + fused_outputs.insert({output, output_order++}); } } + + if (graph_output_names.find(output->Name()) != graph_output_names.end()) { + // This output is the graph's output. + // So the output should be put into the subgraph's output list. + graph_outputs_to_add.insert({output, output_order++}); + } } } diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/clip_op_builder.cc b/onnxruntime/core/providers/vsinpu/builders/impl/clip_op_builder.cc index 85096d0e262d7..9948069c6779b 100644 --- a/onnxruntime/core/providers/vsinpu/builders/impl/clip_op_builder.cc +++ b/onnxruntime/core/providers/vsinpu/builders/impl/clip_op_builder.cc @@ -78,8 +78,8 @@ bool ClipOpBuilder::HandleBuildOp(vsi::npu::GraphEP* graph_ep, LOGS_DEFAULT(INFO) << "Creating Clip Op."; if (node_unit.SinceVersion() <= 6) { NodeAttrHelper helper(node_unit.GetNode()); - auto min = helper.Get("min", -3.402e+38f); - auto max = helper.Get("max", 3.402e+38f); + auto min = helper.Get("min", -3.4028234663852886e+38f); + auto max = helper.Get("max", 3.4028234663852886e+38f); auto op = graph_ep->GetGraph()->CreateOperation(min, max); (*op).BindInputs(inputs).BindOutputs(outputs); graph_ep->GetOps().push_back(std::move(op)); diff --git a/onnxruntime/core/providers/webgpu/allocator.cc b/onnxruntime/core/providers/webgpu/allocator.cc index b3eb4b5061423..3e1b87821fe2f 100644 --- a/onnxruntime/core/providers/webgpu/allocator.cc +++ b/onnxruntime/core/providers/webgpu/allocator.cc @@ -13,7 +13,7 @@ GpuBufferAllocator::GpuBufferAllocator(const BufferManager& buffer_manager, bool OrtMemoryInfo(WEBGPU_BUFFER, is_read_only_allocator ? 
OrtAllocatorType::OrtReadOnlyAllocator : OrtAllocatorType::OrtDeviceAllocator, - OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NONE, 0), + WebGpuDevice, OrtMemTypeDefault)), buffer_manager_{buffer_manager}, mapped_at_creation_{is_read_only_allocator && buffer_manager.SupportsUMA()} { diff --git a/onnxruntime/core/providers/webgpu/allocator.h b/onnxruntime/core/providers/webgpu/allocator.h index 7c38b4557e078..74b3d669fcf3b 100644 --- a/onnxruntime/core/providers/webgpu/allocator.h +++ b/onnxruntime/core/providers/webgpu/allocator.h @@ -11,6 +11,11 @@ namespace webgpu { class BufferManager; +inline constexpr OrtDevice WebGpuDevice{OrtDevice::GPU, + OrtDevice::MemType::DEFAULT, + OrtDevice::VendorIds::NONE, + 0}; + class GpuBufferAllocator : public IAllocator { public: GpuBufferAllocator(const BufferManager& buffer_manager, bool is_read_only_allocator); diff --git a/onnxruntime/core/providers/webgpu/compute_context.cc b/onnxruntime/core/providers/webgpu/compute_context.cc index ebe71c6ccfacd..d1a2011c8e191 100644 --- a/onnxruntime/core/providers/webgpu/compute_context.cc +++ b/onnxruntime/core/providers/webgpu/compute_context.cc @@ -6,22 +6,25 @@ namespace onnxruntime { namespace webgpu { -ComputeContext::ComputeContext(OpKernelContext& kernel_context, - const OpKernel& op_kernel, - const WebGpuExecutionProvider& ep, - WebGpuContext& webgpu_context) + +ComputeContextBase::ComputeContextBase(WebGpuContext& webgpu_context, + const WebGpuExecutionProvider& ep, + const OpKernel& op_kernel) : webgpu_context_{webgpu_context}, - kernel_context_{kernel_context}, - op_kernel_{op_kernel}, - ep_{ep} { + ep_{ep}, + op_kernel_{op_kernel} { } -const webgpu::BufferManager& ComputeContext::BufferManagerAccessor::Get(const ComputeContext& context) { +const webgpu::BufferManager& ComputeContextBase::BufferManagerAccessor::Get(const ComputeContextBase& context) { return context.ep_.BufferManager(); } -const SplitKConfig& ComputeContext::GetSplitKConfig() { - return webgpu_context_.GetSplitKConfig(); +ComputeContext::ComputeContext(WebGpuContext& webgpu_context, + const WebGpuExecutionProvider& ep, + const OpKernel& op_kernel, + OpKernelContext& kernel_context) + : ComputeContextBase(webgpu_context, ep, op_kernel), + kernel_context_{kernel_context} { } } // namespace webgpu diff --git a/onnxruntime/core/providers/webgpu/compute_context.h b/onnxruntime/core/providers/webgpu/compute_context.h index ed16f2f0a1345..fdf89854469d6 100644 --- a/onnxruntime/core/providers/webgpu/compute_context.h +++ b/onnxruntime/core/providers/webgpu/compute_context.h @@ -24,7 +24,13 @@ namespace webgpu { class WebGpuContext; class BufferManager; -class ComputeContext final { +// +// Class ComputeContextBase is designed to provide basic context information +// for running a compute shader program. +// +// An instance of ComputeContextBase does not depend on OpKernelContext, which needs an execution frame to be created. 
+// +class ComputeContextBase { public: // Nested accessor class to provide controlled access to BufferManager class BufferManagerAccessor { @@ -34,18 +40,31 @@ class ComputeContext final { friend class WebGpuContext; private: - static const webgpu::BufferManager& Get(const ComputeContext& context); + static const webgpu::BufferManager& Get(const ComputeContextBase& context); }; - ComputeContext(OpKernelContext& kernel_context, - const OpKernel& op_kernel, - const WebGpuExecutionProvider& ep, - WebGpuContext& webgpu_context); + ComputeContextBase(WebGpuContext& webgpu_context, + const WebGpuExecutionProvider& ep, + const OpKernel& op_kernel); - ~ComputeContext() = default; + ~ComputeContextBase() = default; + + // + // Get the node name. + // + inline decltype(auto) NodeName() const { + return op_kernel_.Node().Name(); + } + + // + // Get the operator type. + // + inline decltype(auto) OpType() const { + return op_kernel_.Node().OpType(); + } // - // Get various information from the context. + // Get various information from the WebGPU context. // inline const wgpu::AdapterInfo& AdapterInfo() const { @@ -57,9 +76,6 @@ class ComputeContext final { inline bool HasFeature(wgpu::FeatureName feature) const { return webgpu_context_.DeviceHasFeature(feature); } - inline bool IsGraphCaptureEnabled() const { - return ep_.IsGraphCaptureEnabled(); - } #if !defined(__wasm__) inline const wgpu::AdapterPropertiesSubgroupMatrixConfigs& SubgroupMatrixConfigs() const { return webgpu_context_.SubgroupMatrixConfigs(); @@ -67,17 +83,57 @@ class ComputeContext final { #endif // - // Get the kernel context. + // Get Split-K configuration. // - inline OpKernelContext& KernelContext() { - return kernel_context_; + inline const SplitKConfig& GetSplitKConfig() const { + return webgpu_context_.GetSplitKConfig(); + } + + // + // Get whether graph capture is enabled. + // + inline bool IsGraphCaptureEnabled() const { + return ep_.IsGraphCaptureEnabled(); } // // Get the logger. // inline const logging::Logger& Logger() const { - return kernel_context_.Logger(); + return *ep_.GetLogger(); + } + + // + // Run a compute shader program. + // + inline Status RunProgram(const ProgramBase& program) { + return webgpu_context_.Run(*this, program); + } + + protected: + WebGpuContext& webgpu_context_; + const WebGpuExecutionProvider& ep_; + const OpKernel& op_kernel_; +}; + +// +// Class ComputeContext provides all information a `ComputeContextBase` provides, and also +// access to `OpKernelContext` for input and output tensors. +// +class ComputeContext final : public ComputeContextBase { + public: + ComputeContext(WebGpuContext& webgpu_context, + const WebGpuExecutionProvider& ep, + const OpKernel& op_kernel, + OpKernelContext& kernel_context); + + ~ComputeContext() = default; + + // + // Get the kernel context. + // + inline OpKernelContext& KernelContext() { + return kernel_context_; } // @@ -145,25 +201,8 @@ class ComputeContext final { return op_kernel_.Info().GetDataTransferManager().CopyTensor(src, dst); } - // - // Run a compute shader program. - // - inline Status RunProgram(const ProgramBase& program) { - return webgpu_context_.Run(*this, program); - } - - // - // Get Split-K configuration. - // - // `split_k_config_` won't be initialized until the first call to this method. 
- // - const SplitKConfig& GetSplitKConfig(); - private: - WebGpuContext& webgpu_context_; OpKernelContext& kernel_context_; - const OpKernel& op_kernel_; - const WebGpuExecutionProvider& ep_; }; } // namespace webgpu diff --git a/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc b/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc index 82645e30082e6..3c974ef5133c0 100644 --- a/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc +++ b/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc @@ -322,11 +322,14 @@ std::string GetPowImpl(int lhs_element_type, int /* rhs_element_type */) { if (lhs_element_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32) { round_str = "round"; } - std::string use_sqrt_for_pow; + std::string use_pow_shortcut; if (lhs_element_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT || lhs_element_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16) { + // use multiplication instead of pow when base (a) is a float and exponent (b) is 2.0 // use sqrt instead of pow when base (a) is a positive float and exponent (b) is 0.5 - use_sqrt_for_pow = - " else if (a >= input_a_element_t(0.0) && b == 0.5) {\n" + use_pow_shortcut = + " else if (b == 2.0) {\n" + " return a * a;\n" + " } else if (a >= input_a_element_t(0.0) && b == 0.5) {\n" " return sqrt(a);\n" " }\n"; } @@ -337,7 +340,7 @@ std::string GetPowImpl(int lhs_element_type, int /* rhs_element_type */) { " } else if (a < input_a_element_t(0.0) && b != floor(b)) {\n" " return input_a_element_t(pow(f32(a), b)); // NaN\n" " }\n" - << use_sqrt_for_pow + << use_pow_shortcut << " return select(sign(a), input_a_element_t(1.0), round(abs(b) % 2.0) != 1.0) * input_a_element_t(" << round_str << "(pow(f32(abs(a)), b)));\n" << "}\n" "fn pow_v(a : vec4, b : vec4) -> vec4 {\n" diff --git a/onnxruntime/core/providers/webgpu/math/gemm_packed.cc b/onnxruntime/core/providers/webgpu/math/gemm_packed.cc index 6aefa90a59285..c26b58a7af1f4 100644 --- a/onnxruntime/core/providers/webgpu/math/gemm_packed.cc +++ b/onnxruntime/core/providers/webgpu/math/gemm_packed.cc @@ -93,18 +93,21 @@ Status ApplyGemmPacked(const Tensor* a, } const uint32_t TILE_SIZE = 32; - const uint32_t num_tile_n = (N + TILE_SIZE - 1) / TILE_SIZE; - const uint32_t num_tile_m = (M + TILE_SIZE - 1) / TILE_SIZE; + const uint32_t dispatch_x = (N + TILE_SIZE - 1) / TILE_SIZE; + const uint32_t dispatch_y = (M + TILE_SIZE - 1) / TILE_SIZE; program.CacheHint(alpha, transA, transB, c_is_scalar) .AddOutputs({{y, ProgramTensorMetadataDependency::TypeAndRank, output_components}}) - .SetDispatchGroupSize(num_tile_n, num_tile_m, 1) + .SetDispatchGroupSize(dispatch_x, dispatch_y, 1u) .SetWorkgroupSize(GemmProgram::MATMUL_PACKED_WORKGROUP_SIZE_X, GemmProgram::MATMUL_PACKED_WORKGROUP_SIZE_Y, GemmProgram::MATMUL_PACKED_WORKGROUP_SIZE_Z) .AddUniformVariables({{alpha}, {beta}, - {M}, /* dim_a_outer */ - {N}, /* dim_b_outer */ - {K}} /*dim_inner */ + {M}, /* dim_a_outer */ + {N}, /* dim_b_outer */ + {K}, /*dim_inner */ + {dispatch_x}, /* logical_dispatch_x */ + {dispatch_y}, /* logical_dispatch_y */ + {1u}} /* logical_dispatch_z */ ); return context.RunProgram(program); diff --git a/onnxruntime/core/providers/webgpu/math/gemm_packed.h b/onnxruntime/core/providers/webgpu/math/gemm_packed.h index dce5164693aa8..cb89ccefba313 100644 --- a/onnxruntime/core/providers/webgpu/math/gemm_packed.h +++ b/onnxruntime/core/providers/webgpu/math/gemm_packed.h @@ -32,7 +32,10 @@ class GemmProgram final : public Program { {"beta", ProgramUniformVariableDataType::Float32}, 
{"dim_a_outer", ProgramUniformVariableDataType::Uint32}, {"dim_b_outer", ProgramUniformVariableDataType::Uint32}, - {"dim_inner", ProgramUniformVariableDataType::Uint32}); + {"dim_inner", ProgramUniformVariableDataType::Uint32}, + {"logical_dispatch_x", ProgramUniformVariableDataType::Uint32}, + {"logical_dispatch_y", ProgramUniformVariableDataType::Uint32}, + {"logical_dispatch_z", ProgramUniformVariableDataType::Uint32}); constexpr static uint32_t MATMUL_PACKED_WORKGROUP_SIZE_X = 8; constexpr static uint32_t MATMUL_PACKED_WORKGROUP_SIZE_Y = 8; diff --git a/onnxruntime/core/providers/webgpu/math/gemm_utils.cc b/onnxruntime/core/providers/webgpu/math/gemm_utils.cc index 7cbc7f6a4a821..89718149cea88 100644 --- a/onnxruntime/core/providers/webgpu/math/gemm_utils.cc +++ b/onnxruntime/core/providers/webgpu/math/gemm_utils.cc @@ -117,6 +117,20 @@ void HandleMatMulWithSplitK( } } +// Compute `logical_workgroup_id` and `logical_global_id` because the dispatch workgroup size in +// `ProgramBase.SetDispatchGroupSize()` may be normalized in +// `ProgramManager::NormalizeDispatchGroupSize()`. In the shader we should always use +// `logical_workgroup_id` and `logical_global_id` instead of `workgroup_id` and `global_id`. +void InitializeLogicalWorkgroupIDAndGlobalID(ShaderHelper& shader) { + shader.MainFunctionBody() + << " let logical_workgroup_id_z = workgroup_idx / (uniforms.logical_dispatch_x * uniforms.logical_dispatch_y);\n" + << " let logical_workgroup_id_y = (workgroup_idx % (uniforms.logical_dispatch_x * uniforms.logical_dispatch_y)) / uniforms.logical_dispatch_x;\n" + << " let logical_workgroup_id_x = (workgroup_idx % (uniforms.logical_dispatch_x * uniforms.logical_dispatch_y)) % uniforms.logical_dispatch_x;\n" + << " let logical_workgroup_id = vec3u(logical_workgroup_id_x, logical_workgroup_id_y, logical_workgroup_id_z);\n" + << " const workgroupSize = vec3u(workgroup_size_x, workgroup_size_y, workgroup_size_z);\n" + << " let logical_global_id = logical_workgroup_id * workgroupSize + local_id;\n"; +} + } // namespace void MatMulReadFnSource(ShaderHelper& shader, @@ -274,20 +288,22 @@ Status MakeMatMulPackedVec4Source(ShaderHelper& shader, << "const innerElementSize = " << inner_elements_size << ";\n" << "const tileInner = " << tile_inner << ";\n"; + InitializeLogicalWorkgroupIDAndGlobalID(shader); + shader.MainFunctionBody() << " let localRow = i32(local_id.y);\n" << " let tileRow = localRow * rowPerThread;\n" << " let tileCol = i32(local_id.x);\n" - << " let globalRow = i32(global_id.y) * rowPerThread;\n" - << " let globalCol = i32(global_id.x);\n" - << " let globalRowStart = i32(workgroup_id.y) * " << tile_a_outer << ";\n" - << " let globalColStart = i32(workgroup_id.x) * " << tile_b_outer << ";\n" + << " let globalRow = i32(logical_global_id.y) * rowPerThread;\n" + << " let globalCol = i32(logical_global_id.x);\n" + << " let globalRowStart = i32(logical_workgroup_id.y) * " << tile_a_outer << ";\n" + << " let globalColStart = i32(logical_workgroup_id.x) * " << tile_b_outer << ";\n" << " var acc: array, rowPerThread>;\n"; if (split_k) { // With Split-K, the original "workgroup" (with dispatch_z == 1 in API side) is split into // multiple ones, and in the current workgroup we only compute `kSplitK` elements starting from - // `kSplitK * i32(global_id.z)`. + // `kSplitK * i32(logical_global_id.z)`. // // For example: considering computing Y = (X * W + B) in one workgroup. 
// Let kSplitK = 2, B = [d1, d2] @@ -305,15 +321,15 @@ Status MakeMatMulPackedVec4Source(ShaderHelper& shader, // Workgroup1: compute (A1 * A2) Workgroup2: compute (B1 * B2) // Workgroup3: compute (C1 * C2) // In each workgroup: - // - `num_tiles` is computed with `kSplitK`, and `kStart` is computed with `global_id.z` + // - `num_tiles` is computed with `kSplitK`, and `kStart` is computed with `logical_global_id.z` // - When the computation in each workgroup is completed, add the result to Y with several // atomic built-in functions in `HandleMatMulWithSplitK()`. shader.MainFunctionBody() << "const kSplitK = " << split_dim_inner << ";\n" << " let num_tiles = (kSplitK - 1) / tileInner + 1;\n" - << " var kStart = kSplitK * i32(global_id.z);\n" + << " var kStart = kSplitK * i32(logical_global_id.z);\n" - // When Split-K is used, `batch` should always be 0 and `global_id.z` is used to indicate + // When Split-K is used, `batch` should always be 0 and `logical_global_id.z` is used to indicate // the index of split-k instead of batch. << " let batch = 0;\n" << " let batchIndices = 0u;\n"; @@ -321,7 +337,7 @@ Status MakeMatMulPackedVec4Source(ShaderHelper& shader, shader.MainFunctionBody() << " let num_tiles = (uniforms.dim_inner - 1) / tileInner + 1;\n" << " var kStart = 0;\n" - << " let batch = i32(global_id.z);\n" + << " let batch = i32(logical_global_id.z);\n" << (nullptr != batch_dims ? " let batchIndices = " + batch_dims->OffsetToIndices("u32(batch)") + ";\n" : ""); } @@ -498,7 +514,9 @@ Status MakeMatMulPackedSource(ShaderHelper& shader, << "const colPerThread = " << elements_per_thread_x << ";\n" << "const tileInner = " << tile_inner << ";\n"; - shader.MainFunctionBody() << " let batch = i32(global_id.z);\n" + InitializeLogicalWorkgroupIDAndGlobalID(shader); + + shader.MainFunctionBody() << " let batch = i32(logical_global_id.z);\n" << (nullptr != batch_dims ? " let batchIndices = " + batch_dims->OffsetToIndices("u32(batch)") + ";\n" : "") << " let num_tiles = (uniforms.dim_inner - 1) / tileInner + 1;\n" << " var kStart = 0;\n" @@ -507,10 +525,10 @@ Status MakeMatMulPackedSource(ShaderHelper& shader, shader.MainFunctionBody() << "let tileRow = i32(local_id.y) * rowPerThread;\n" << "let tileCol = i32(local_id.x) * colPerThread;\n" - << "let globalRow = i32(global_id.y) * rowPerThread;\n" - << "let globalCol = i32(global_id.x) * colPerThread;\n" - << "let globalRowStart = i32(workgroup_id.y) * " << tile_a_outer << ";\n" - << "let globalColStart = i32(workgroup_id.x) * " << tile_b_outer << ";\n" + << "let globalRow = i32(logical_global_id.y) * rowPerThread;\n" + << "let globalCol = i32(logical_global_id.x) * colPerThread;\n" + << "let globalRowStart = i32(logical_workgroup_id.y) * " << tile_a_outer << ";\n" + << "let globalColStart = i32(logical_workgroup_id.x) * " << tile_b_outer << ";\n" << "let tileRowA = i32(local_id.y) * " << row_per_thread_a << ";\n" << "let tileColA = i32(local_id.x) * " << col_per_thread_a << ";\n" << "let tileRowB = i32(local_id.y) * " << row_per_thread_b << ";\n"; diff --git a/onnxruntime/core/providers/webgpu/math/matmul.cc b/onnxruntime/core/providers/webgpu/math/matmul.cc index 55c2c5773cc1f..72dd235eb820a 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul.cc +++ b/onnxruntime/core/providers/webgpu/math/matmul.cc @@ -256,8 +256,6 @@ Status ComputeMatMul(ComputeContext* context, // With Split-K, `dim_inner` will be split into multiple parts and `dispatch_z` will be the // number of splits along `dim_inner`. 
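Since `ProgramManager::NormalizeDispatchGroupSize()` may reshape the dispatch grid, the WGSL emitted by `InitializeLogicalWorkgroupIDAndGlobalID()` above rebuilds 3-D IDs from the flat `workgroup_idx` using the `logical_dispatch_*` uniforms. A host-side C++ sketch of the same decomposition (illustrative only, mirroring the shader):

```cpp
#include <array>
#include <cstdint>

// Rebuild (x, y, z) from a flattened workgroup index, given the logical
// dispatch extents. This mirrors the WGSL emitted above.
std::array<uint32_t, 3> LogicalWorkgroupId(uint32_t workgroup_idx,
                                           uint32_t logical_dispatch_x,
                                           uint32_t logical_dispatch_y) {
  const uint32_t plane = logical_dispatch_x * logical_dispatch_y;  // workgroups per z layer
  const uint32_t z = workgroup_idx / plane;
  const uint32_t y = (workgroup_idx % plane) / logical_dispatch_x;
  const uint32_t x = (workgroup_idx % plane) % logical_dispatch_x;
  return {x, y, z};
}
```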
- // TODO: avoid using `global_id.xxx` or `workgroup_id.xxx` in `MatMulProgram` when we normalize - // the dispatch size with `ProgramManager::NormalizeDispatchGroupSize()` for `MatMulProgram`. split_dim_inner = split_k_config.GetSplitDimInner(); dispatch_z = (dim_inner + split_dim_inner - 1) / split_dim_inner; @@ -271,7 +269,7 @@ Status ComputeMatMul(ComputeContext* context, .CacheHint(activation.ToString(), absl::StrJoin(elements_per_thread, "-"), std::to_string(is_vec4), components, is_channels_last, split_dim_inner) .AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, a_shape_temp, components}, {b, ProgramTensorMetadataDependency::TypeAndRank, b_shape_temp, components}}) - .AddUniformVariables({{dim_a_outer}, {dim_b_outer}, {dim_inner}}) + .AddUniformVariables({{dim_a_outer}, {dim_b_outer}, {dim_inner}, {dispatch_x}, {dispatch_y}, {dispatch_z}}) .AddIndices(outer_dims) .SetDispatchGroupSize(dispatch_x, dispatch_y, dispatch_z) .SetWorkgroupSize(MatMul::MATMUL_PACKED_WORKGROUP_SIZE_X, MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Y, MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Z) diff --git a/onnxruntime/core/providers/webgpu/math/matmul_packed.h b/onnxruntime/core/providers/webgpu/math/matmul_packed.h index 143ba61c99e13..dbd193bc38f58 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul_packed.h +++ b/onnxruntime/core/providers/webgpu/math/matmul_packed.h @@ -24,7 +24,10 @@ class MatMulProgram final : public Program { Status GenerateShaderCode(ShaderHelper& sh) const override; WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"dim_a_outer", ProgramUniformVariableDataType::Uint32}, {"dim_b_outer", ProgramUniformVariableDataType::Uint32}, - {"dim_inner", ProgramUniformVariableDataType::Uint32}); + {"dim_inner", ProgramUniformVariableDataType::Uint32}, + {"logical_dispatch_x", ProgramUniformVariableDataType::Uint32}, + {"logical_dispatch_y", ProgramUniformVariableDataType::Uint32}, + {"logical_dispatch_z", ProgramUniformVariableDataType::Uint32}); bool NeedSplitK() const; diff --git a/onnxruntime/core/providers/webgpu/math/softmax.cc b/onnxruntime/core/providers/webgpu/math/softmax.cc index 2f34aa21c8309..bf3bb53341418 100644 --- a/onnxruntime/core/providers/webgpu/math/softmax.cc +++ b/onnxruntime/core/providers/webgpu/math/softmax.cc @@ -64,7 +64,7 @@ Status SoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const { int components = input.NumComponents(); const std::string thread_max_decl = is_fp32_ - ? "var thread_max = x_value_t(-3.402823e+38f);\n" + ? "var thread_max = x_value_t(-3.4028234663852886e+38f);\n" : "var thread_max = x_value_t(-65504.0h);\n"; // Define shared memory for row max and row sum diff --git a/onnxruntime/core/providers/webgpu/nn/conv.cc b/onnxruntime/core/providers/webgpu/nn/conv.cc index 77fa46cb87518..4fff736fd2f32 100644 --- a/onnxruntime/core/providers/webgpu/nn/conv.cc +++ b/onnxruntime/core/providers/webgpu/nn/conv.cc @@ -216,6 +216,46 @@ Status Conv::ComputeInternal(ComputeContext& context return context.RunProgram(conv2d_mm_program); } +template +Status Conv::PrePackInternal(ComputeContextBase& /* context */, + const Tensor& tensor, + int input_idx, + AllocatorPtr /* alloc */, + /*out*/ bool& is_packed) { + is_packed = false; + + if constexpr (is_channels_last) { + if (input_idx == 1 && tensor.Shape().NumDimensions() == 4) { + // only deal with 4D NHWC weights + + // TODO: implement weight transpose for pre-pack here + // Conv::ComputeInternal() should be updated to reflect the change: + // - if the initializer is packed, `context.Input(1)` will be nullptr. 
+ // - in this case, use `transposed_kernel_` instead. + + // // Step.1 - calculate transposed weight shape + // TensorShape transposed_kernel_shape{tensor.Shape()[2], + // tensor.Shape()[3], + // tensor.Shape()[1], + // tensor.Shape()[0]}; + + // // Step.2 - create transposed weight tensor + // transposed_kernel_ = std::make_unique<Tensor>(tensor.DataType(), transposed_kernel_shape, alloc); + + // // Step.3 - do transpose + // size_t perm[] = {2, 3, 1, 0}; + // ORT_RETURN_IF_ERROR(Transpose::DoTranspose(context, + // perm, + // tensor, + // *transposed_kernel_)); + + // is_packed = true; // set this flag to true so that ORT will release the initializer tensor + } + } + + return Status::OK(); +} + // Explicit template instantiation for FusedConv template class Conv<false, true>; template class Conv<true, true>; diff --git a/onnxruntime/core/providers/webgpu/nn/conv.h b/onnxruntime/core/providers/webgpu/nn/conv.h index cafaa272c0613..5bf94a459a44a 100644 --- a/onnxruntime/core/providers/webgpu/nn/conv.h +++ b/onnxruntime/core/providers/webgpu/nn/conv.h @@ -23,9 +23,16 @@ class Conv : public WebGpuKernel { } Status ComputeInternal(ComputeContext& context) const override; + Status PrePackInternal(ComputeContextBase& context, + const Tensor& tensor, + int input_idx, + AllocatorPtr alloc, + /*out*/ bool& is_packed) override; + protected: ConvAttributes conv_attrs_; Activation activation_; + std::unique_ptr<Tensor> transposed_kernel_; // should only have value when `is_initializer` AND `is_4D` AND `is_NHWC` }; Status TransposeKernel(ComputeContext& context, const Tensor* kernel, const TensorShape& kernel_shape, Tensor* transposed_kernel, const InlinedVector<size_t>& perm); diff --git a/onnxruntime/core/providers/webgpu/nn/conv2d_mm.cc b/onnxruntime/core/providers/webgpu/nn/conv2d_mm.cc index 2d5424c52a3f2..c66f2cbd582d9 100644 --- a/onnxruntime/core/providers/webgpu/nn/conv2d_mm.cc +++ b/onnxruntime/core/providers/webgpu/nn/conv2d_mm.cc @@ -226,7 +226,10 @@ Conv2dMMProgram CreateConv2dMMProgram(const Activation& activation, const std::v {static_cast<uint32_t>(dim_inner)}, {pads}, {strides}, - {dilations}}); + {dilations}, + {dispatch[0]}, + {dispatch[1]}, + {dispatch[2]}}); return program; } diff --git a/onnxruntime/core/providers/webgpu/nn/conv2d_mm.h b/onnxruntime/core/providers/webgpu/nn/conv2d_mm.h index d7cc08aae26f3..e161bffb0c503 100644 --- a/onnxruntime/core/providers/webgpu/nn/conv2d_mm.h +++ b/onnxruntime/core/providers/webgpu/nn/conv2d_mm.h @@ -38,7 +38,10 @@ class Conv2dMMProgram final : public Program<Conv2dMMProgram> { {"dim_inner", ProgramUniformVariableDataType::Uint32}, {"pads", ProgramUniformVariableDataType::Uint32}, {"strides", ProgramUniformVariableDataType::Uint32}, - {"dilations", ProgramUniformVariableDataType::Uint32}); + {"dilations", ProgramUniformVariableDataType::Uint32}, + {"logical_dispatch_x", ProgramUniformVariableDataType::Uint32}, + {"logical_dispatch_y", ProgramUniformVariableDataType::Uint32}, + {"logical_dispatch_z", ProgramUniformVariableDataType::Uint32}); private: const Activation& activation_; diff --git a/onnxruntime/core/providers/webgpu/tensor/slice.cc b/onnxruntime/core/providers/webgpu/tensor/slice.cc index 7e8b434431781..5f59fecc425e2 100644 --- a/onnxruntime/core/providers/webgpu/tensor/slice.cc +++ b/onnxruntime/core/providers/webgpu/tensor/slice.cc @@ -92,14 +92,28 @@ Status SliceProgram::GenerateShaderCode(ShaderHelper& shader) const { return Status::OK(); } +static std::vector<int64_t> getInt64Input(const Tensor* tensor) { + if (tensor->IsDataType<int64_t>()) { + return std::vector<int64_t>(tensor->DataAsSpan<int64_t>().begin(), tensor->DataAsSpan<int64_t>().end()); + } + ORT_ENFORCE(tensor->IsDataType<int32_t>(), "Expected tensor of type int32 or int64"); + std::vector<int64_t> result; + auto span = tensor->DataAsSpan<int32_t>(); + result.reserve(span.size()); + for (auto v : span) { + result.push_back(static_cast<int64_t>(v)); + } + return result; +} + Status Slice::ComputeInternal(ComputeContext& context) const { // READ INPUTS const Tensor* input_tensor = context.Input<Tensor>(0); const TensorShape& input_shape = input_tensor->Shape(); auto input_rank = input_shape.NumDimensions(); - auto starts_raw = attr_starts_.empty() ? context.Input<Tensor>(1)->DataAsSpan<int64_t>() : gsl::make_span(attr_starts_); - auto ends_raw = attr_ends_.empty() ? context.Input<Tensor>(2)->DataAsSpan<int64_t>() : gsl::make_span(attr_ends_); + auto starts_raw = attr_starts_.empty() ? getInt64Input(context.Input<Tensor>(1)) : attr_starts_; + auto ends_raw = attr_ends_.empty() ? getInt64Input(context.Input<Tensor>(2)) : attr_ends_; ORT_ENFORCE(starts_raw.size() == ends_raw.size(), "starts and ends must have the same size"); @@ -126,7 +140,7 @@ Status Slice::ComputeInternal(ComputeContext& context) const { axes_default.push_back(i); } } - auto axes_raw = attr_axes_.empty() ? (axes_tensor == nullptr ? gsl::make_span(axes_default) : axes_tensor->DataAsSpan<int64_t>()) : gsl::make_span(attr_axes_); + auto axes_raw = attr_axes_.empty() ? (axes_tensor == nullptr ? axes_default : getInt64Input(axes_tensor)) : attr_axes_; std::vector<int64_t> steps_default; if (steps_tensor == nullptr) { @@ -135,7 +149,7 @@ Status Slice::ComputeInternal(ComputeContext& context) const { steps_default.push_back(1); } } - auto steps_raw = steps_tensor == nullptr ? gsl::make_span(steps_default) : steps_tensor->DataAsSpan<int64_t>(); + auto steps_raw = steps_tensor == nullptr ? steps_default : getInt64Input(steps_tensor); // get final axes std::vector axes, axes_fixed; diff --git a/onnxruntime/core/providers/webgpu/tensor/transpose.cc b/onnxruntime/core/providers/webgpu/tensor/transpose.cc index cec321d0da80e..5415d4a5ead5b 100644 --- a/onnxruntime/core/providers/webgpu/tensor/transpose.cc +++ b/onnxruntime/core/providers/webgpu/tensor/transpose.cc @@ -108,7 +108,7 @@ Status TransposeProgram::GenerateShaderCode(ShaderHelper& shader) const { return Status::OK(); } -Status Transpose::DoTranspose(onnxruntime::webgpu::ComputeContext& context, +Status Transpose::DoTranspose(onnxruntime::webgpu::ComputeContextBase& context, gsl::span<const size_t> permutations, const Tensor& input, Tensor& output) { const auto& input_shape = input.Shape(); diff --git a/onnxruntime/core/providers/webgpu/tensor/transpose.h b/onnxruntime/core/providers/webgpu/tensor/transpose.h index b62a419fa12bc..5e9ccc6750cd6 100644 --- a/onnxruntime/core/providers/webgpu/tensor/transpose.h +++ b/onnxruntime/core/providers/webgpu/tensor/transpose.h @@ -16,7 +16,7 @@ class Transpose final : public WebGpuKernel, public TransposeBase { Transpose(const OpKernelInfo& info) : WebGpuKernel{info}, TransposeBase{info} { } Status ComputeInternal(ComputeContext& context) const override; - static Status DoTranspose(onnxruntime::webgpu::ComputeContext& context, gsl::span<const size_t> permutations, const Tensor& input, Tensor& output); + static Status DoTranspose(onnxruntime::webgpu::ComputeContextBase& context, gsl::span<const size_t> permutations, const Tensor& input, Tensor& output); constexpr static uint32_t TILE_SIZE = 16; }; diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc index 28decb076951e..b8d5adc421124 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.cc +++
b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -147,6 +147,9 @@ void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_confi // create program manager program_mgr_ = std::make_unique(*this); + // create split-k config + split_k_config_ = std::make_unique(adapter_info_); + // set query type #if !defined(__wasm__) if (DeviceHasFeature(wgpu::FeatureName::ChromiumExperimentalTimestampQueryInsidePasses)) { @@ -178,7 +181,7 @@ Status WebGpuContext::Wait(wgpu::Future f) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to wait for the operation:", uint32_t(status)); } -Status WebGpuContext::Run(ComputeContext& context, const ProgramBase& program) { +Status WebGpuContext::Run(ComputeContextBase& context, const ProgramBase& program) { const auto& inputs = program.Inputs(); const auto& outputs = program.Outputs(); @@ -288,8 +291,8 @@ Status WebGpuContext::Run(ComputeContext& context, const ProgramBase& program) { auto key = CalculateProgramCacheKey(program, inputs_segments, outputs_segments, is_1d_dispatch); if (is_profiling_) { - PendingKernelInfo pending_kernel_info(context.KernelContext().GetNodeName(), - context.KernelContext().GetOpType(), + PendingKernelInfo pending_kernel_info(context.NodeName(), + context.OpType(), program.Name(), key, inputs, @@ -442,7 +445,7 @@ Status WebGpuContext::Run(ComputeContext& context, const ProgramBase& program) { const size_t uniform_buffer_total_size = (current_offset + max_alignment_of_field - 1) / max_alignment_of_field * max_alignment_of_field; WGPUBuffer uniform_buffer = nullptr; - const webgpu::BufferManager& buffer_mgr = ComputeContext::BufferManagerAccessor::Get(context); + const webgpu::BufferManager& buffer_mgr = ComputeContextBase::BufferManagerAccessor::Get(context); if (uniform_buffer_total_size > 0) { std::vector uniform_data_buffer(uniform_buffer_total_size); @@ -910,13 +913,6 @@ void WebGpuContext::ReleaseGraphResources(std::vector WebGpuContextFactory::contexts_; std::mutex WebGpuContextFactory::mutex_; std::once_flag WebGpuContextFactory::init_default_flag_; diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.h b/onnxruntime/core/providers/webgpu/webgpu_context.h index bd7dae75f2e2d..84dfb47ef4687 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.h +++ b/onnxruntime/core/providers/webgpu/webgpu_context.h @@ -5,7 +5,6 @@ #include #include -#include #include "core/providers/webgpu/webgpu_external_header.h" @@ -23,7 +22,7 @@ class Tensor; namespace webgpu { class WebGpuContext; -class ComputeContext; +class ComputeContextBase; class ProgramBase; // Definition for CapturedCommandInfo in the webgpu namespace @@ -152,6 +151,13 @@ class WebGpuContext final { return validation_mode_; } + // + // Get Split-K configuration. + // + const SplitKConfig& GetSplitKConfig() const { + return *split_k_config_; + } + void StartProfiling(); void CollectProfilingData(profiling::Events& events); void EndProfiling(TimePoint, profiling::Events& events, profiling::Events& cached_events); @@ -170,16 +176,9 @@ class WebGpuContext final { // Status PopErrorScope(); - Status Run(ComputeContext& context, const ProgramBase& program); + Status Run(ComputeContextBase& context, const ProgramBase& program); void OnRunEnd(); - // - // Get Split-K configuration. - // - // `split_k_config_` won't be initialized until the first call to this method. 
- // - const SplitKConfig& GetSplitKConfig(); - private: enum class TimestampQueryType { None = 0, @@ -277,7 +276,7 @@ class WebGpuContext final { uint32_t num_pending_dispatches_ = 0; const uint32_t max_num_pending_dispatches_ = 16; - std::optional<SplitKConfig> split_k_config_; + std::unique_ptr<SplitKConfig> split_k_config_; // profiling TimestampQueryType query_type_; diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index e0b84fef51f1f..6b764d51bcf75 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -794,8 +794,7 @@ using namespace webgpu; WebGpuExecutionProvider::WebGpuExecutionProvider(int context_id, WebGpuContext& context, WebGpuExecutionProviderConfig&& config) - : IExecutionProvider{kWebGpuExecutionProvider, - OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NONE, 0)}, + : IExecutionProvider{kWebGpuExecutionProvider, WebGpuDevice}, context_id_{context_id}, context_{context}, preferred_data_layout_{config.data_layout}, @@ -935,13 +934,14 @@ std::unique_ptr WebGpuExecutionProvider::GetEx std::optional<bool> WebGpuExecutionProvider::ShouldConvertDataLayoutForOp(std::string_view node_domain, std::string_view node_op_type, DataLayout target_data_layout) const { - if (target_data_layout != DataLayout::NHWC) { - return std::nullopt; - } - // NHWC for Resize operator is not implemented on kWebGpuExecutionProvider if (node_domain == kOnnxDomain && node_op_type == "Resize") { - return false; + return target_data_layout != DataLayout::NHWC; + } + + // WebGPU prefers NCHW for InstanceNormalization due to better performance + if (node_domain == kOnnxDomain && node_op_type == "InstanceNormalization") { + return target_data_layout != DataLayout::NHWC; } return std::nullopt; diff --git a/onnxruntime/core/providers/webgpu/webgpu_kernel.cc b/onnxruntime/core/providers/webgpu/webgpu_kernel.cc index 8d6ae6caeaf83..ea38e9415e1fe 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_kernel.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_kernel.cc @@ -11,25 +11,58 @@ namespace webgpu { WebGpuKernel::WebGpuKernel(const OpKernelInfo& info) : OpKernel(info), - ep_(*static_cast<const WebGpuExecutionProvider*>(info.GetExecutionProvider())) { + ep_(*static_cast<const WebGpuExecutionProvider*>(info.GetExecutionProvider())), + webgpu_context_(WebGpuContextFactory::GetContext(ep_.GetDeviceId())) { } Status WebGpuKernel::Compute(OpKernelContext* p_op_kernel_context) const { - WebGpuContext& webgpu_context = WebGpuContextFactory::GetContext(ep_.GetDeviceId()); - ComputeContext context{*p_op_kernel_context, *this, ep_, webgpu_context}; + ComputeContext context{webgpu_context_, + ep_, + *this, + *p_op_kernel_context}; - if (webgpu_context.ValidationMode() >= ValidationMode::Full) { - webgpu_context.PushErrorScope(); + if (webgpu_context_.ValidationMode() >= ValidationMode::Full) { + webgpu_context_.PushErrorScope(); } Status s = ComputeInternal(context); - if (webgpu_context.ValidationMode() >= ValidationMode::Full) { - ORT_RETURN_IF_ERROR(webgpu_context.PopErrorScope()); + if (webgpu_context_.ValidationMode() >= ValidationMode::Full) { + ORT_RETURN_IF_ERROR(webgpu_context_.PopErrorScope()); } return s; } +Status WebGpuKernel::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + /*out*/ bool& is_packed, /*out*/ PrePackedWeights* /* prepacked_weights */) { + ComputeContextBase context{webgpu_context_, ep_, *this}; + + if (webgpu_context_.ValidationMode() >= ValidationMode::Full) {
+ webgpu_context_.PushErrorScope(); + } + + // Currently, ORT does not allow using prepacked weights in non-CPU EPs. + // So we do not pass prepacked_weights to PrePackInternal. + // Kernel implementation that supports prepacking should manage its own storage. + + Status s = PrePackInternal(context, tensor, input_idx, alloc, is_packed); + + if (webgpu_context_.ValidationMode() >= ValidationMode::Full) { + ORT_RETURN_IF_ERROR(webgpu_context_.PopErrorScope()); + } + + return s; +} + +Status WebGpuKernel::PrePackInternal(ComputeContextBase& /*context*/, + const Tensor& /*tensor*/, + int /*input_idx*/, + AllocatorPtr /*alloc*/, + /*out*/ bool& is_packed) { + is_packed = false; + return Status::OK(); +} + } // namespace webgpu } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_kernel.h b/onnxruntime/core/providers/webgpu/webgpu_kernel.h index 3c750e305421c..2c57991c6ee35 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_kernel.h +++ b/onnxruntime/core/providers/webgpu/webgpu_kernel.h @@ -23,8 +23,41 @@ class WebGpuKernel : public OpKernel { virtual Status ComputeInternal(ComputeContext& context) const = 0; + // Overrides OpKernel::PrePack to handle constant tensor pre-processing for WebGPU kernels. + // This method creates a ComputeContextBase and delegates to PrePackInternal. + // + // NOTE: Currently, ORT does not allow using prepacked weights in non-CPU EPs, so the + // prepacked_weights parameter is not passed to PrePackInternal. Kernel implementations + // that support prepacking should manage their own storage. + Status PrePack(const Tensor& tensor, + int input_idx, + AllocatorPtr alloc, + /*out*/ bool& is_packed, + /*out*/ PrePackedWeights* prepacked_weights) override; + + // Virtual method that allows derived kernels to pre-process constant tensors during initialization. + // + // This method is called during kernel initialization when constant tensors are available, + // allowing kernels to perform operations like tensor transposition or format conversion + // before the first Compute call. + // + // @param context The WebGPU compute context base providing access to the execution environment. + // @param tensor The constant tensor to potentially pre-process. + // @param input_idx The index of this input in the kernel's input list. + // @param alloc The allocator to use for any new tensor allocations. + // @param is_packed Output parameter. Set to true if the tensor was pre-packed/processed, + // false otherwise. The default implementation sets this to false. + // + // @return Status::OK() on success, or an error status on failure. 
+ virtual Status PrePackInternal(ComputeContextBase& context, + const Tensor& tensor, + int input_idx, + AllocatorPtr alloc, + /*out*/ bool& is_packed); + private: const WebGpuExecutionProvider& ep_; + WebGpuContext& webgpu_context_; }; } // namespace webgpu diff --git a/onnxruntime/core/providers/webgpu/webgpu_utils.cc b/onnxruntime/core/providers/webgpu/webgpu_utils.cc index 568d29a96cb88..5fd24b2bff037 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_utils.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_utils.cc @@ -21,27 +21,24 @@ TensorShape ReduceShapeByComponents(const TensorShape& shape, int64_t components return TensorShape(shape_vector); } -SplitKConfig SplitKConfig::GetSplitKConfig(const wgpu::AdapterInfo& adapter_info) { - SplitKConfig config = {}; - +SplitKConfig::SplitKConfig(const wgpu::AdapterInfo& adapter_info) { if (adapter_info.vendor == std::string_view{"intel"}) { if (adapter_info.architecture == std::string_view{"xe-2lpg"} || adapter_info.architecture == std::string_view{"xe-2hpg"} || adapter_info.architecture == std::string_view{"xe-lpg"} || adapter_info.architecture == std::string_view{"gen-12hp"}) { - config.enable_split_k_ = true; + enable_split_k_ = true; // Below thresholds are only verified on the above Intel GPUs without any regressions. The // proper value of `max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner_` may be // reduced when we support a larger `dim_inner` because larger `dim_inner` will bring more // atomic calls for each output value. - config.split_dim_inner_ = 256; - config.min_dim_inner_with_split_k_ = config.split_dim_inner_ * 2; - config.max_dim_inner_with_split_k_ = config.split_dim_inner_ * 9; - config.max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner_ = 35.0f; + split_dim_inner_ = 256; + min_dim_inner_with_split_k_ = split_dim_inner_ * 2; + max_dim_inner_with_split_k_ = split_dim_inner_ * 9; + max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner_ = 35.0f; } } - return config; } bool SplitKConfig::UseSplitK( diff --git a/onnxruntime/core/providers/webgpu/webgpu_utils.h b/onnxruntime/core/providers/webgpu/webgpu_utils.h index d45b9bf4dd119..7d5ab5fea8006 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_utils.h +++ b/onnxruntime/core/providers/webgpu/webgpu_utils.h @@ -91,9 +91,12 @@ inline Tensor CreateTensorView(const Tensor& tensor, MLDataType new_data_type, c return {new_data_type, new_shape, const_cast(tensor.DataRaw()), tensor.Location()}; } +/** + * Configuration for Split-K optimization (Conv|MatMul). 
+ */ class SplitKConfig { public: - static SplitKConfig GetSplitKConfig(const wgpu::AdapterInfo& adapter_info); + explicit SplitKConfig(const wgpu::AdapterInfo& adapter_info); bool UseSplitK( bool is_vec4, ActivationKind activation_kind, uint64_t batch_size, diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 4d4dea9cb444c..ab3932e7abfb4 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -2943,6 +2943,8 @@ Status InferenceSession::Run(const RunOptions& run_options, << cached_execution_provider_for_graph_replay_.Type() << " CUDA Graph for this model with tag: " << run_options.run_tag << " with graph annotation id: " << graph_annotation_id; + // Log evaluation start to the trace logging provider. + env.GetTelemetryProvider().LogEvaluationStart(session_id_); ORT_RETURN_IF_ERROR_SESSIONID_(cached_execution_provider_for_graph_replay_.ReplayGraph(graph_annotation_id)); } else { InlinedVector<IExecutionProvider*> exec_providers_to_stop; diff --git a/onnxruntime/core/session/utils.cc b/onnxruntime/core/session/utils.cc index 6189e6ca7f012..4cb21b80109c8 100644 --- a/onnxruntime/core/session/utils.cc +++ b/onnxruntime/core/session/utils.cc @@ -404,6 +404,7 @@ Status CompileModel(const Environment& env, const ModelCompilationOptions& model session))); } + Env::Default().GetTelemetryProvider().LogCompileModel(session->GetCurrentSessionId()); ORT_RETURN_IF_ERROR(ToStatusAndRelease(InitializeSession(session_options, *session))); return Status::OK(); } diff --git a/onnxruntime/test/providers/cuda/cuda_mempool_arena_test.cc b/onnxruntime/test/providers/cuda/cuda_mempool_arena_test.cc index 70c7a5b2bcdcb..5deef01cd783e 100644 --- a/onnxruntime/test/providers/cuda/cuda_mempool_arena_test.cc +++ b/onnxruntime/test/providers/cuda/cuda_mempool_arena_test.cc @@ -22,10 +22,17 @@ namespace test { // --------- Helpers --------- +// CUDA errors are sticky and may affect subsequent API calls. +// We want to clear the error when a support check fails.
+void ClearCudaError() { + ORT_IGNORE_RETURN_VALUE(::cudaGetLastError()); +} + static bool IsCudaMemPoolSupported() { int ort_cuda_rt_version = 0; cudaError_t cuda_status = cudaRuntimeGetVersion(&ort_cuda_rt_version); if (cuda_status != cudaSuccess) { + ClearCudaError(); return false; } @@ -36,6 +43,7 @@ static bool IsCudaMemPoolSupported() { int ort_cuda_driver_version = 0; cuda_status = cudaDriverGetVersion(&ort_cuda_driver_version); if (cuda_status != cudaSuccess) { + ClearCudaError(); return false; } @@ -65,9 +73,10 @@ static bool IsCudaMemPoolSupported() { cudaMemPool_t pool; auto cuda_error = cudaMemPoolCreate(&pool, &props); if (cuda_error != cudaSuccess) { + ClearCudaError(); return false; } - cuda_error = cudaMemPoolDestroy(pool); + ORT_IGNORE_RETURN_VALUE(cudaMemPoolDestroy(pool)); return true; } @@ -80,7 +89,9 @@ static ::cudaStream_t NewCudaStream() { } static void DestroyCudaStream(::cudaStream_t s) { - if (s) (void)::cudaStreamDestroy(s); + if (s) { + EXPECT_EQ(cudaSuccess, ::cudaStreamDestroy(s)); + } } static void TouchDevice(void* p, size_t bytes, ::cudaStream_t s, unsigned char value = 0xAB) { diff --git a/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc index d8cc56d738175..af9706855ee3c 100644 --- a/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc +++ b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc @@ -203,6 +203,48 @@ TEST_P(TypeTests, IOTypes) { } } +TEST(NvExecutionProviderTest, TestSessionOutputs) { + /* + * Model #1: + * + * "input" ---> TopK --- + * |---> "scores" + * |--- Less ---> "Less_output_0" + * |--- Div ---> "Div_output_0" + * |--- Mod ---> "labels" + */ + { + Ort::SessionOptions session_options; + session_options.AppendExecutionProvider(kNvTensorRTRTXExecutionProvider, {}); + + auto model_path = ORT_TSTR("testdata/topk_and_multiple_graph_outputs.onnx"); + Ort::Session session(*ort_env, model_path, session_options); + + size_t output_count = session.GetOutputCount(); + ASSERT_TRUE(output_count == 4); + } + + /* + * Model #2: + * + * "X" ---> Dropout ---> MatMul ---> "Y" + * ^ | + * | | + * "W" ------ ----> Can't be graph's output + * + */ + { + Ort::SessionOptions session_options; + session_options.AppendExecutionProvider(kNvTensorRTRTXExecutionProvider, {}); + + auto model_path = ORT_TSTR("testdata/node_output_not_used.onnx"); + Ort::Session session(*ort_env, model_path, session_options); + + size_t output_count = session.GetOutputCount(); + ASSERT_TRUE(output_count == 1); + } +} + INSTANTIATE_TEST_SUITE_P(NvExecutionProviderTest, TypeTests, ::testing::Values(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, diff --git a/onnxruntime/test/providers/qnn/README.md b/onnxruntime/test/providers/qnn/README.md new file mode 100644 index 0000000000000..c3d0c720a1aa4 --- /dev/null +++ b/onnxruntime/test/providers/qnn/README.md @@ -0,0 +1,70 @@ +# ONNX Runtime QNN Execution Provider Tests +## Overview +1. The `onnxruntime/test/providers/qnn` directory contains integration tests for the Qualcomm Neural Network (QNN) execution provider. +2. Most test cases run an ONNX model through the QNN EP, then verify the inference results against those produced by the CPU EP. + +## Building the Tests +The tests are built as part of the regular ONNX Runtime build. After a successful build you will have an executable named +- onnxruntime_provider_test.exe (Windows) +- onnxruntime_provider_test (Linux/macOS) + +## Running the Tests +1. QNN supports several backends.
You can use the standard Google‑Test syntax for filtering: + - `onnxruntime_provider_test.exe --gtest_filter=QnnCPUBackendTests.*` + - `onnxruntime_provider_test.exe --gtest_filter=QnnHTPBackendTests.*` + - `onnxruntime_provider_test.exe --gtest_filter=QnnGPUBackendTests.*` + - `onnxruntime_provider_test.exe --gtest_filter=QnnIRBackendTests.*` +2. Saving Test Artifacts + - For debugging it is often helpful to keep the intermediate files that the tests generate. The following environment + variables are recognized by the test binary: + - `QNN_DUMP_ONNX`: Saves the input ONNX model used for the test + - `QNN_DUMP_JSON`: Saves the JSON QNN graph via the provider option `dump_json_qnn_graph` + - `QNN_DUMP_DLC`: Saves the compiled QNN DLC file by pointing the provider option `qnn_ir_backend_path` to `QnnIr.dll` + - The artifacts will be saved to a directory named `<test_suite_name>_<test_name>` + ``` + . + ├── QnnCPUBackendTests_BatchNorm2D_fp32 # RunQnnModelTest + │ ├── dumped_f32_model.onnx # float32 ONNX model + │ ├── QNNExecutionProvider_QNN_XXXX_X_X.dlc + │ └── QNNExecutionProvider_QNN_XXXX_X_X.json + ├── QnnHTPBackendTests_BatchNorm_FP16 # TestFp16ModelAccuracy + │ ├── dumped_f16_model.onnx # float16 ONNX model + │ ├── dumped_f32_model.onnx # float32 ONNX model + │ ├── QNNExecutionProvider_QNN_XXXX_X_X.dlc + │ └── QNNExecutionProvider_QNN_XXXX_X_X.json + └── QnnHTPBackendTests_BatchNorm2D_U8U8S32 # TestQDQModelAccuracy + ├── dumped_f32_model.onnx # float32 ONNX model + ├── dumped_qdq_model.onnx # QDQ ONNX model + ├── QNNExecutionProvider_QNN_XXXX_X_X.dlc + └── QNNExecutionProvider_QNN_XXXX_X_X.json + + # All artifact files are placed under the current working directory from which the test binary is invoked. + ``` +3. Verbose + - `QNN_VERBOSE`: Sets the ONNX Runtime log level to `ORT_LOGGING_LEVEL_VERBOSE` + +4. You can enable any combination of these environment variables, for example: + - On Linux/macOS + ```bash + export QNN_DUMP_ONNX=1 + export QNN_DUMP_JSON=1 + export QNN_DUMP_DLC=1 + export QNN_VERBOSE=1 + ``` + - On Windows + ```cmd + set QNN_DUMP_ONNX=1 + set QNN_DUMP_JSON=1 + set QNN_DUMP_DLC=1 + set QNN_VERBOSE=1 + ``` + ```ps1 + $Env:QNN_DUMP_ONNX = "1" + $Env:QNN_DUMP_JSON = "1" + $Env:QNN_DUMP_DLC = "1" + $Env:QNN_VERBOSE = "1" + ``` + +## Notes +- A failure in the QNN backend can prevent the test artifacts from being saved. +- `onnxruntime_provider_test` does not automatically delete the artifact directories, so you may want to prune them after a debugging session.
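For reference, the artifact directory naming described above boils down to joining the Google Test suite and test names under the current working directory. A minimal C++ sketch of that derivation (mirroring `QNNTestEnvironment::CreateTestcaseDirs()` from `qnn_test_utils.h` later in this change):

```cpp
#include <filesystem>
#include <string>

// Creates (if needed) and returns the per-testcase artifact directory,
// e.g. <cwd>/QnnHTPBackendTests_BatchNorm_FP16.
std::filesystem::path MakeArtifactDir(const std::string& test_suite_name,
                                      const std::string& test_name) {
  std::filesystem::path dir =
      std::filesystem::current_path() / (test_suite_name + "_" + test_name);
  std::filesystem::create_directories(dir);
  return dir;
}
```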
diff --git a/onnxruntime/test/providers/qnn/qnn_test_utils.cc b/onnxruntime/test/providers/qnn/qnn_test_utils.cc index 1c70f4012090e..15a9132aaa16c 100644 --- a/onnxruntime/test/providers/qnn/qnn_test_utils.cc +++ b/onnxruntime/test/providers/qnn/qnn_test_utils.cc @@ -101,6 +101,12 @@ void RunQnnModelTest(const GetTestModelFn& build_test_case, ProviderOptions prov int opset_version, ExpectedEPNodeAssignment expected_ep_assignment, float fp32_abs_err, logging::Severity log_severity, bool verify_outputs, std::function* ep_graph_checker) { + std::filesystem::path output_dir; + if (QNNTestEnvironment::GetInstance().dump_onnx() || + QNNTestEnvironment::GetInstance().dump_json() || + QNNTestEnvironment::GetInstance().dump_dlc()) { + output_dir = QNNTestEnvironment::GetInstance().CreateTestcaseDirs(); + } EPVerificationParams verification_params; verification_params.ep_node_assignment = expected_ep_assignment; verification_params.fp32_abs_err = fp32_abs_err; @@ -110,6 +116,10 @@ void RunQnnModelTest(const GetTestModelFn& build_test_case, ProviderOptions prov auto& logging_manager = DefaultLoggingManager(); logging_manager.SetDefaultLoggerSeverity(log_severity); + if (QNNTestEnvironment::GetInstance().verbose()) { + logging_manager.RemoveSink(logging::SinkType::EtwSink); + logging_manager.SetDefaultLoggerSeverity(logging::Severity::kVERBOSE); + } onnxruntime::Model model("QNN_EP_TestModel", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {}, @@ -123,7 +133,27 @@ void RunQnnModelTest(const GetTestModelFn& build_test_case, ProviderOptions prov // Serialize the model to a string. std::string model_data; model.ToProto().SerializeToString(&model_data); + + if (QNNTestEnvironment::GetInstance().dump_onnx()) { + auto dump_path = output_dir / ToPathString("dumped_f32_model.onnx"); + LOGS(logging_manager.DefaultLogger(), VERBOSE) << "Save onnx model at: " << dump_path; + ASSERT_STATUS_OK(onnxruntime::Model::Save(model, dump_path)); + } + TryEnableQNNSaver(provider_options); + if (QNNTestEnvironment::GetInstance().dump_dlc()) { + provider_options["dump_qnn_ir_dlc"] = "1"; + provider_options["dump_qnn_ir_dlc_dir"] = output_dir.string(); +#if defined(_WIN32) + provider_options["qnn_ir_backend_path"] = "QnnIr.dll"; +#else + provider_options["qnn_ir_backend_path"] = "libQnnIr.so"; +#endif // defined(_WIN32) + } + if (QNNTestEnvironment::GetInstance().dump_json()) { + provider_options["dump_json_qnn_graph"] = "1"; + provider_options["json_qnn_graph_dir"] = output_dir.string(); + } RunAndVerifyOutputsWithEP(AsByteSpan(model_data.data(), model_data.size()), "QNN_EP_TestLogID", QnnExecutionProviderWithOptions(provider_options), helper.feeds_, verification_params, @@ -134,11 +164,21 @@ void RunQnnModelTestHTPNoVerify(const GetTestModelFn& build_test_case, ProviderO int opset_version, ExpectedEPNodeAssignment expected_ep_assignment, logging::Severity log_severity, std::function* ep_graph_checker) { + std::filesystem::path output_dir; + if (QNNTestEnvironment::GetInstance().dump_onnx() || + QNNTestEnvironment::GetInstance().dump_dlc() || + QNNTestEnvironment::GetInstance().dump_json()) { + output_dir = QNNTestEnvironment::GetInstance().CreateTestcaseDirs(); + } // Add kMSDomain to cover contrib op like Gelu const std::unordered_map domain_to_version = {{"", opset_version}, {kMSDomain, 1}}; auto& logging_manager = DefaultLoggingManager(); logging_manager.SetDefaultLoggerSeverity(log_severity); + if (QNNTestEnvironment::GetInstance().verbose()) { + 
+    logging_manager.RemoveSink(logging::SinkType::EtwSink);
+    logging_manager.SetDefaultLoggerSeverity(logging::Severity::kVERBOSE);
+  }
   onnxruntime::Model model("QNN_EP_TestModel", false, ModelMetaData(), PathString(),
                            IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {},
@@ -152,7 +192,27 @@
   // Serialize the model to a string.
   std::string model_data;
   model.ToProto().SerializeToString(&model_data);
+
+  if (QNNTestEnvironment::GetInstance().dump_onnx()) {
+    auto dump_path = output_dir / ToPathString("dumped_f32_model.onnx");
+    LOGS(logging_manager.DefaultLogger(), VERBOSE) << "Save onnx model at: " << dump_path;
+    ASSERT_STATUS_OK(onnxruntime::Model::Save(model, dump_path));
+  }
+
   TryEnableQNNSaver(provider_options);
+  if (QNNTestEnvironment::GetInstance().dump_dlc()) {
+    provider_options["dump_qnn_ir_dlc"] = "1";
+    provider_options["dump_qnn_ir_dlc_dir"] = output_dir.string();
+#if defined(_WIN32)
+    provider_options["qnn_ir_backend_path"] = "QnnIr.dll";
+#else
+    provider_options["qnn_ir_backend_path"] = "libQnnIr.so";
+#endif  // defined(_WIN32)
+  }
+  if (QNNTestEnvironment::GetInstance().dump_json()) {
+    provider_options["dump_json_qnn_graph"] = "1";
+    provider_options["json_qnn_graph_dir"] = output_dir.string();
+  }
   SessionOptions so;
   so.session_logid = "QNN_EP_TestLogID";
diff --git a/onnxruntime/test/providers/qnn/qnn_test_utils.h b/onnxruntime/test/providers/qnn/qnn_test_utils.h
index aeb3a9a114871..4d4f795d161b1 100644
--- a/onnxruntime/test/providers/qnn/qnn_test_utils.h
+++ b/onnxruntime/test/providers/qnn/qnn_test_utils.h
@@ -499,6 +499,77 @@ struct QDQTolerance {
   float value;
 };
+class QNNTestEnvironment {
+ public:
+  // Delete copy constructor and assignment operator
+  QNNTestEnvironment(const QNNTestEnvironment&) = delete;
+  QNNTestEnvironment& operator=(const QNNTestEnvironment&) = delete;
+
+  // Static method to get the singleton instance
+  static QNNTestEnvironment& GetInstance() {
+    static QNNTestEnvironment instance;
+    return instance;
+  }
+
+  bool dump_onnx() const { return dump_onnx_; }
+  bool dump_json() const { return dump_json_; }
+  bool dump_dlc() const { return dump_dlc_; }
+  bool verbose() const { return verbose_; }
+
+  std::filesystem::path CreateTestcaseDirs() {
+    std::string test_suite_name = ::testing::UnitTest::GetInstance()->current_test_info()->test_suite_name();
+    std::string test_name = ::testing::UnitTest::GetInstance()->current_test_info()->name();
+    std::filesystem::path output_dir = std::filesystem::current_path() / (test_suite_name + "_" + test_name);
+    std::filesystem::create_directories(output_dir);
+
+    return output_dir;
+  }
+
+ private:
+  // Private constructor for singleton
+  QNNTestEnvironment() {
+    ParseEnvironmentVars();
+  }
+
+  // Helper function to check if an environment variable is set
+  bool IsEnvVarSet(const char* name) {
+    const char* value = std::getenv(name);
+    if (value == nullptr) {
+      return false;
+    }
+
+    // Consider the variable set if it's not empty and not "0"
+    return *value != '\0' && *value != '0';
+  }
+
+  void ParseEnvironmentVars() {
+    if (IsEnvVarSet("QNN_DUMP_ONNX")) {
+      std::cout << "[QNN only] ONNX model dumping enabled via environment variable." << std::endl;
+      dump_onnx_ = true;
+    }
+
+    if (IsEnvVarSet("QNN_DUMP_JSON")) {
+      std::cout << "[QNN only] Json QNN Graph dumping enabled via environment variable." << std::endl;
+      dump_json_ = true;
+    }
+
+    if (IsEnvVarSet("QNN_DUMP_DLC")) {
+      std::cout << "[QNN only] DLC dumping enabled via environment variable." << std::endl;
+      dump_dlc_ = true;
+    }
+
+    if (IsEnvVarSet("QNN_VERBOSE")) {
+      std::cout << "Verbose enabled via environment variable." << std::endl;
+      verbose_ = true;
+    }
+  }
+
+  bool dump_onnx_ = false;
+  bool dump_json_ = false;
+  bool dump_dlc_ = false;
+  bool verbose_ = false;
+};
+
 /**
  * Tests the accuracy of a QDQ model on QNN EP by running 3 inferences:
  *
@@ -529,15 +600,21 @@ inline void TestQDQModelAccuracy(const GetTestModelFn& f32_model_fn, const GetTe
                                  const std::string& qnn_ctx_model_path = "",
                                  const std::unordered_map<std::string, std::string>& session_option_pairs = {},
                                  std::function<void(const Graph&)>* qnn_ep_graph_checker = nullptr) {
+  std::filesystem::path output_dir;
+  if (QNNTestEnvironment::GetInstance().dump_onnx() ||
+      QNNTestEnvironment::GetInstance().dump_dlc() ||
+      QNNTestEnvironment::GetInstance().dump_json()) {
+    output_dir = QNNTestEnvironment::GetInstance().CreateTestcaseDirs();
+  }
   // Add kMSDomain to cover contrib op like Gelu
   const std::unordered_map<std::string, int> domain_to_version = {{"", opset_version}, {kMSDomain, 1}};
   auto& logging_manager = DefaultLoggingManager();
-
-  // Uncomment to dump LOGGER() output to stdout.
-  // logging_manager.RemoveSink(logging::SinkType::EtwSink);
-
   logging_manager.SetDefaultLoggerSeverity(log_severity);
+  if (QNNTestEnvironment::GetInstance().verbose()) {
+    logging_manager.RemoveSink(logging::SinkType::EtwSink);
+    logging_manager.SetDefaultLoggerSeverity(logging::Severity::kVERBOSE);
+  }

   // Create float model and serialize it to a string.
   onnxruntime::Model f32_model("f32_model", false, ModelMetaData(), PathString(),
@@ -551,8 +628,11 @@
   ASSERT_STATUS_OK(f32_model.MainGraph().Resolve());
   f32_model.ToProto().SerializeToString(&f32_model_data);
-  // Uncomment to save f32 model to disk for debugging.
-  // ASSERT_STATUS_OK(onnxruntime::Model::Save(f32_model, ToPathString("cmp_accuracy.f32.onnx")));
+  if (QNNTestEnvironment::GetInstance().dump_onnx()) {
+    auto dump_path = output_dir / ToPathString("dumped_f32_model.onnx");
+    LOGS(logging_manager.DefaultLogger(), VERBOSE) << "Save onnx float32 model at: " << dump_path;
+    ASSERT_STATUS_OK(onnxruntime::Model::Save(f32_model, dump_path));
+  }

   // Run f32 model on CPU EP and collect outputs.
   std::vector<OrtValue> cpu_f32_outputs;
@@ -594,11 +674,27 @@
   ASSERT_STATUS_OK(qdq_model.MainGraph().Resolve());
   qdq_model.ToProto().SerializeToString(&qdq_model_data);
-  // Uncomment to save QDQ model to disk for debugging.
-  // ASSERT_STATUS_OK(onnxruntime::Model::Save(qdq_model, ToPathString("cmp_accuracy.qdq.onnx")));
+  if (QNNTestEnvironment::GetInstance().dump_onnx()) {
+    auto dump_path = output_dir / ToPathString("dumped_qdq_model.onnx");
+    LOGS(logging_manager.DefaultLogger(), VERBOSE) << "Save onnx QDQ model at: " << dump_path;
+    ASSERT_STATUS_OK(onnxruntime::Model::Save(qdq_model, dump_path));
+  }

   bool is_qnn_ep = true;
   TryEnableQNNSaver(qnn_options);
+  if (QNNTestEnvironment::GetInstance().dump_dlc()) {
+    qnn_options["dump_qnn_ir_dlc"] = "1";
+    qnn_options["dump_qnn_ir_dlc_dir"] = output_dir.string();
+#if defined(_WIN32)
+    qnn_options["qnn_ir_backend_path"] = "QnnIr.dll";
+#else
+    qnn_options["qnn_ir_backend_path"] = "libQnnIr.so";
+#endif  // defined(_WIN32)
+  }
+  if (QNNTestEnvironment::GetInstance().dump_json()) {
+    qnn_options["dump_json_qnn_graph"] = "1";
+    qnn_options["json_qnn_graph_dir"] = output_dir.string();
+  }
   std::vector<OrtValue> qnn_qdq_outputs;
   if (!qnn_ctx_model_path.empty()) {
     onnx::ModelProto model_proto;
@@ -743,11 +839,21 @@ inline void TestFp16ModelAccuracy(const GetTestModelFn& f32_model_fn,
                                   logging::Severity log_severity = logging::Severity::kERROR,
                                   const std::string& qnn_ctx_model_path = "",
                                   const std::unordered_map<std::string, std::string>& session_option_pairs = {}) {
+  std::filesystem::path output_dir;
+  if (QNNTestEnvironment::GetInstance().dump_onnx() ||
+      QNNTestEnvironment::GetInstance().dump_dlc() ||
+      QNNTestEnvironment::GetInstance().dump_json()) {
+    output_dir = QNNTestEnvironment::GetInstance().CreateTestcaseDirs();
+  }
   // Add kMSDomain to cover contrib op like Gelu
   const std::unordered_map<std::string, int> domain_to_version = {{"", opset_version}, {kMSDomain, 1}};
   auto& logging_manager = DefaultLoggingManager();
   logging_manager.SetDefaultLoggerSeverity(log_severity);
+  if (QNNTestEnvironment::GetInstance().verbose()) {
+    logging_manager.RemoveSink(logging::SinkType::EtwSink);
+    logging_manager.SetDefaultLoggerSeverity(logging::Severity::kVERBOSE);
+  }

   // Create float model and serialize it to a string.
   onnxruntime::Model f32_model("f32_model", false, ModelMetaData(), PathString(),
@@ -760,6 +866,12 @@
   ASSERT_STATUS_OK(f32_model.MainGraph().Resolve());
   f32_model.ToProto().SerializeToString(&f32_model_data);

+  if (QNNTestEnvironment::GetInstance().dump_onnx()) {
+    auto dump_path = output_dir / ToPathString("dumped_f32_model.onnx");
+    LOGS(logging_manager.DefaultLogger(), VERBOSE) << "Save onnx float32 model at: " << dump_path;
+    ASSERT_STATUS_OK(onnxruntime::Model::Save(f32_model, dump_path));
+  }
+
   // Run f32 model on CPU EP and collect outputs.
   std::vector<OrtValue> cpu_f32_outputs;
   InferenceModel(f32_model_data, "f32_model_logger", {}, ExpectedEPNodeAssignment::All,
@@ -796,8 +908,27 @@
   ASSERT_STATUS_OK(f16_model.MainGraph().Resolve());
   f16_model.ToProto().SerializeToString(&f16_model_data);

+  if (QNNTestEnvironment::GetInstance().dump_onnx()) {
+    auto dump_path = output_dir / ToPathString("dumped_f16_model.onnx");
+    LOGS(logging_manager.DefaultLogger(), VERBOSE) << "Save onnx float16 model at: " << dump_path;
+    ASSERT_STATUS_OK(onnxruntime::Model::Save(f16_model, dump_path));
+  }
+
   bool is_qnn_ep = true;
   TryEnableQNNSaver(qnn_options);
+  if (QNNTestEnvironment::GetInstance().dump_dlc()) {
+    qnn_options["dump_qnn_ir_dlc"] = "1";
+    qnn_options["dump_qnn_ir_dlc_dir"] = output_dir.string();
+#if defined(_WIN32)
+    qnn_options["qnn_ir_backend_path"] = "QnnIr.dll";
+#else
+    qnn_options["qnn_ir_backend_path"] = "libQnnIr.so";
+#endif  // defined(_WIN32)
+  }
+  if (QNNTestEnvironment::GetInstance().dump_json()) {
+    qnn_options["dump_json_qnn_graph"] = "1";
+    qnn_options["json_qnn_graph_dir"] = output_dir.string();
+  }
   std::vector<OrtValue> qnn_f16_outputs;
   if (!qnn_ctx_model_path.empty()) {
     onnx::ModelProto model_proto;
diff --git a/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc b/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc
index 6a6545c68cb4f..dce0d570ec238 100644
--- a/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc
+++ b/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc
@@ -1,5 +1,6 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
+// Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
+#include "onnxruntime_cxx_api.h"
 #include "core/graph/onnx_protobuf.h"
 #include "core/session/inference_session.h"
 #include "test/providers/provider_test_utils.h"
@@ -18,6 +19,8 @@
 using namespace std;
 using namespace ONNX_NAMESPACE;
 using namespace ::onnxruntime::logging;

+extern std::unique_ptr<Ort::Env> ort_env;
+
 namespace onnxruntime {
 namespace test {
@@ -1360,5 +1363,49 @@ TEST(TensorrtExecutionProviderTest, RemoveCycleTest) {
   ASSERT_STATUS_OK(session_object.Run(run_options, feeds, output_names, &fetches));
   VerifyOutputs(fetches, expected_dims_mul_m, expected_values_mul_m);
 }
+
+TEST(TensorrtExecutionProviderTest, TestSessionOutputs) {
+  /*
+   * Model #1:
+   *
+   * "input" ---> TopK ---
+   *                     |---> "scores"
+   *                     |--- Less ---> "Less_output_0"
+   *                     |--- Div ---> "Div_17_output_0"
+   *                     |--- Mod ---> "labels"
+   */
+  {
+    OrtTensorRTProviderOptionsV2 provider_options;
+    Ort::SessionOptions session_options;
+    session_options.AppendExecutionProvider_TensorRT_V2(provider_options);
+
+    auto model_path = ORT_TSTR("testdata/topk_and_multiple_graph_outputs.onnx");
+    Ort::Session session(*ort_env, model_path, session_options);
+
+    size_t output_count = session.GetOutputCount();
+    ASSERT_TRUE(output_count == 4);
+  }
+
+  /*
+   * Model #2:
+   *
+   * "X" ---> Dropout ---> MatMul ---> "Y"
+   *              |            ^
+   *              |            |
+   *              |           "W"
+   *              ----> Can't be graph's output
+   */
+  {
+    OrtTensorRTProviderOptionsV2 provider_options;
+    Ort::SessionOptions session_options;
+    session_options.AppendExecutionProvider_TensorRT_V2(provider_options);
+
+    auto model_path = ORT_TSTR("testdata/node_output_not_used.onnx");
+    Ort::Session session(*ort_env, model_path, session_options);
+
+    size_t output_count = session.GetOutputCount();
+    ASSERT_TRUE(output_count == 1);
+  }
+}
 }  // namespace test
 }  // namespace onnxruntime
diff --git a/onnxruntime/test/testdata/node_output_not_used.onnx b/onnxruntime/test/testdata/node_output_not_used.onnx
new file mode 100644
index 0000000000000..e2726182fddc2
Binary files /dev/null and b/onnxruntime/test/testdata/node_output_not_used.onnx differ
diff --git a/onnxruntime/test/testdata/node_output_not_used.py b/onnxruntime/test/testdata/node_output_not_used.py
new file mode 100644
index 0000000000000..d36d5e9cfd2f8
--- /dev/null
+++ b/onnxruntime/test/testdata/node_output_not_used.py
@@ -0,0 +1,43 @@
+import onnx
+from onnx import TensorProto, helper
+
+
+def create_model_with_node_output_not_used(model_path):
+    # Create graph
+    x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [3, 2])
+    w = helper.make_tensor_value_info("W", TensorProto.FLOAT, [2, 3])
+    y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 3])
+
+    # Dropout node (two outputs)
+    dropout_node = helper.make_node(
+        "Dropout",
+        inputs=["X"],
+        outputs=["dropout_out", "dropout_mask"],
+        name="DropoutNode",
+    )
+
+    # MatMul node
+    matmul_node = helper.make_node(
+        "MatMul",
+        inputs=["dropout_out", "W"],
+        outputs=["Y"],
+        name="MatMulNode",
+    )
+
+    graph = helper.make_graph(
+        nodes=[dropout_node, matmul_node],
+        name="DropoutMatMulGraph",
+        inputs=[x, w],
+        outputs=[y],
+    )
+
+    model = helper.make_model(graph, opset_imports=[helper.make_operatorsetid("", 13)])
+
+    onnx.checker.check_model(model)
+    onnx.save(model, model_path)
+
+    print(f"Model saved to: {model_path}")
+
+
+if __name__ == "__main__":
+    create_model_with_node_output_not_used("node_output_not_used.onnx")
diff --git a/onnxruntime/test/testdata/topk_and_multiple_graph_outputs.onnx b/onnxruntime/test/testdata/topk_and_multiple_graph_outputs.onnx
new file mode 100644
index 0000000000000..340c3d420d574
Binary files /dev/null and b/onnxruntime/test/testdata/topk_and_multiple_graph_outputs.onnx differ
diff --git a/onnxruntime/test/testdata/topk_and_multiple_graph_outputs.py b/onnxruntime/test/testdata/topk_and_multiple_graph_outputs.py
new file mode 100644
index 0000000000000..232abb2ed9163
--- /dev/null
+++ b/onnxruntime/test/testdata/topk_and_multiple_graph_outputs.py
@@ -0,0 +1,78 @@
+import onnx
+from onnx import TensorProto, helper
+
+
+def create_model_with_topk_graph_output(model_path):
+    # ======================
+    # ---- Inputs ----
+    # ======================
+    input_tensor = helper.make_tensor_value_info("input", TensorProto.FLOAT, ["N"])
+
+    # ======================
+    # ---- Initializers ----
+    # ======================
+    k = helper.make_tensor("K", TensorProto.INT64, dims=[1], vals=[300])
+    zero = helper.make_tensor("zero", TensorProto.INT64, dims=[], vals=[0])
+    twenty_six = helper.make_tensor("twenty_six", TensorProto.INT64, dims=[], vals=[26])
+
+    # ======================
+    # ---- Nodes ----
+    # ======================
+    topk_node = helper.make_node(
+        "TopK",
+        inputs=["input", "K"],
+        outputs=["scores", "topk_indices"],
+        name="TopK",
+    )
+
+    less_node = helper.make_node(
+        "Less",
+        inputs=["topk_indices", "zero"],
+        outputs=["Less_output_0"],
+        name="Less",
+    )
+
+    div_node = helper.make_node(
+        "Div",
+        inputs=["topk_indices", "twenty_six"],
+        outputs=["Div_17_output_0"],
+        name="Div",
+    )
+
+    mod_node = helper.make_node(
+        "Mod",
+        inputs=["topk_indices", "twenty_six"],
+        outputs=["labels"],
+        name="Mod",
+    )
+
+    # =========================
+    # ---- Graph Outputs ----
+    # =========================
+    scores_out = helper.make_tensor_value_info("scores", TensorProto.FLOAT, ["K"])
+    less_out = helper.make_tensor_value_info("Less_output_0", TensorProto.BOOL, ["K"])
+    div_out = helper.make_tensor_value_info("Div_17_output_0", TensorProto.INT64, ["K"])
+    labels_out = helper.make_tensor_value_info("labels", TensorProto.INT64, ["K"])
+
+    # ======================
+    # ---- Graph ----
+    # ======================
+    graph = helper.make_graph(
+        nodes=[topk_node, less_node, div_node, mod_node],
+        name="TopKGraph",
+        inputs=[input_tensor],
+        outputs=[scores_out, less_out, div_out, labels_out],
+        initializer=[k, zero, twenty_six],
+    )
+
+    model = helper.make_model(graph, opset_imports=[helper.make_operatorsetid("", 13)])
+
+    # Validate + Save
+    onnx.checker.check_model(model)
+    onnx.save(model, model_path)
+
+    print(f"Model saved to: {model_path}")
+
+
+if __name__ == "__main__":
+    create_model_with_topk_graph_output("topk_and_multiple_graph_outputs.onnx")
diff --git a/tools/ci_build/github/linux/build_rocm_c_api_package.sh b/tools/ci_build/github/linux/build_rocm_c_api_package.sh
deleted file mode 100755
index 3ea90c73342a5..0000000000000
--- a/tools/ci_build/github/linux/build_rocm_c_api_package.sh
+++ /dev/null
@@ -1,40 +0,0 @@
-#!/bin/bash
-
-set -e -u -x
-
-usage() { echo "Usage: $0 -S <source dir> -B <build dir> -V <rocm version> [-H <rocm home>] " 1>&2; exit 1; }
-
-ROCM_HOME=/opt/rocm
-
-while getopts S:B:V:H:I:P: parameter_Option; do
-  case "${parameter_Option}" in
-  S) SOURCE_DIR=${OPTARG};;
-  B) BINARY_DIR=${OPTARG};;
-  V) ROCM_VERSION=${OPTARG};;
-  H) ROCM_HOME=${OPTARG};;
-  I) IMAGE=${OPTARG};;
-  P) PYTHON_BIN=${OPTARG};;
-  *) usage ;;
-  esac
-done
-
-EXIT_CODE=1
-
-docker run -e SYSTEM_COLLECTIONURI --rm \
-  --security-opt seccomp=unconfined \
-  --shm-size=1024m \
-  --user $UID:$(id -g $USER) \
-  -e NIGHTLY_BUILD \
-  --volume $SOURCE_DIR:/onnxruntime_src \
-  --volume $BINARY_DIR:/build \
-  --volume /data/models:/build/models:ro \
-  --volume /data/onnx:/data/onnx:ro \
-  --workdir /onnxruntime_src \
-  $IMAGE \
-  /bin/bash -c "${PYTHON_BIN:-python} /onnxruntime_src/tools/ci_build/build.py --config Release --build_dir /build --parallel --use_rocm --use_binskim_compliant_compile_flags --rocm_version=$ROCM_VERSION --rocm_home $ROCM_HOME --nccl_home $ROCM_HOME --build_shared_lib --skip_submodule_sync --skip_tests --cmake_extra_defines FETCHCONTENT_TRY_FIND_PACKAGE_MODE=NEVER && cd /build/Release && make install DESTDIR=/build/installed"
-
-
-EXIT_CODE=$?
-
-set -e
-exit $EXIT_CODE
diff --git a/tools/ci_build/github/linux/docker/scripts/setup_rocm_yum_repo.sh b/tools/ci_build/github/linux/docker/scripts/setup_rocm_yum_repo.sh
deleted file mode 100755
index 0be64d96f3a34..0000000000000
--- a/tools/ci_build/github/linux/docker/scripts/setup_rocm_yum_repo.sh
+++ /dev/null
@@ -1,43 +0,0 @@
-#!/bin/bash
-set -e -x
-
-# version
-ROCM_VERSION=6.2.3
-
-while getopts "r:" parameter_Option
-do case "${parameter_Option}"
-in
-r) ROCM_VERSION=${OPTARG};;
-esac
-done
-
-tee /etc/yum.repos.d/amdgpu.repo <