From ab4831d83b7ed85f4c6192be38b9c59365e4e468 Mon Sep 17 00:00:00 2001 From: Copilot <198982749+Copilot@users.noreply.github.com> Date: Tue, 25 Nov 2025 08:10:21 -0800 Subject: [PATCH 01/20] Fix npm audit vulnerabilities in /js directory (#26632) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Description Resolved all security vulnerabilities in JavaScript packages under `/js` by running `npm audit fix`. All updates are non-breaking patch/minor version bumps. **Fixed vulnerabilities:** - `/js` root: 1 high severity - `glob` 10.4.5 → 10.5.0 (command injection - GHSA-5j98-mcp5-4vw2) - `/js/react_native`: 7 vulnerabilities (1 high, 3 moderate, 3 low) - `image-size` → 1.2.1 (high: DoS via infinite loop - GHSA-m5qc-5hw7-8vg7) - `@babel/helpers` 7.25.6 → 7.28.4 (moderate: RegExp complexity - GHSA-968p-4wvh-cqc8) - `@babel/runtime` 7.25.6 → 7.28.4 (moderate: RegExp complexity - GHSA-968p-4wvh-cqc8) - `js-yaml` → fixed (moderate: prototype pollution - GHSA-mh29-5h37-fv8m) - `brace-expansion` 2.0.1 → 2.0.2 (low: ReDoS - GHSA-v6h2-p8h4-qcjw) - `on-headers` → fixed (low: header manipulation - GHSA-76c9-3jph-rj3q) **Files modified:** - `js/package-lock.json` - `js/react_native/package-lock.json` **Result:** All JS packages (`/js`, `/js/common`, `/js/web`, `/js/node`, `/js/react_native`) now report 0 vulnerabilities. ### Motivation and Context Security maintenance to address dependency vulnerabilities identified by `npm audit`. No breaking changes or code modifications required.
Original prompt > Please create a pull request that runs `npm audit fix` for the JavaScript/TypeScript portion of the repository under the `/js` directory of [microsoft/onnxruntime](https://github.com/microsoft/onnxruntime). > > Requirements: > > 1. **Scope** > - Work only within the `/js` folder and its subpackages (e.g., `js/web`, `js/node`, `js/common`, etc.). > - Do not modify files outside `/js`. > > 2. **Dependency updates** > - Run `npm audit fix` (and, if necessary to fully resolve high/critical issues while staying non-breaking, `npm audit fix --force` on specific subpackages) to address security vulnerabilities. > - Prefer minimal, non-breaking version bumps (patch and minor) that satisfy `npm audit` while keeping semver ranges sensible. > - If any **major** upgrades are required to clear vulnerabilities, handle them cautiously: > - Apply the upgrade only if tests still pass and typings/build setup remain compatible. > - If a major bump would require code changes or creates breaking behavior, **do not** apply it; instead, leave a TODO comment in the PR description summarizing which packages remain vulnerable and why. > > 3. **Validation** > - Run the existing JS-related checks that the repo supports from `/js`, such as: > - `npm test` or package-specific test scripts. > - Any documented lint/build/test commands for JS packages (e.g., `npm run build`, `npm run lint`) where applicable. > - Ensure the updated lockfiles (if present) are consistent, and the project installs cleanly with `npm ci` (or the repo's documented install command) in the `/js` area. > > 4. **Files to update** > - Update `package.json` and lockfiles under `/js` (e.g., `package-lock.json`, `npm-shrinkwrap.json`, or workspace-specific lock files) to reflect the audited dependency tree. > - Do not manually edit `node_modules`; rely on `npm` to manage dependencies and only commit manifest/lockfile changes. > > 5. **Repository conventions** > - Follow this repo's existing conventions for formatting, commit messages, and JS tooling. > - Keep the diff focused on the dependency and lockfile updates plus any absolutely necessary code tweaks to maintain compatibility. > > 6. **Pull request description** > - In the PR body, include: > - A short summary: that `npm audit fix` was run in `/js` to address dependency vulnerabilities. > - A bullet list of notable dependency changes (especially any major version bumps), with packages and old/new versions. > - A brief testing summary (commands run and their results). > - A note about any remaining vulnerabilities that could not be fixed without breaking changes (if applicable), including the affected packages and advisories if available. > > The goal is a clean, minimal PR that improves the security posture of the JS packages under `/js` in `microsoft/onnxruntime` without introducing breaking changes.
*This pull request was created as a result of the following prompt from Copilot chat.* > Please create a pull request that runs `npm audit fix` for the JavaScript/TypeScript portion of the repository under the `/js` directory of [microsoft/onnxruntime](https://github.com/microsoft/onnxruntime). > > Requirements: > > 1. **Scope** > - Work only within the `/js` folder and its subpackages (e.g., `js/web`, `js/node`, `js/common`, etc.). > - Do not modify files outside `/js`. > > 2. **Dependency updates** > - Run `npm audit fix` (and, if necessary to fully resolve high/critical issues while staying non-breaking, `npm audit fix --force` on specific subpackages) to address security vulnerabilities. > - Prefer minimal, non-breaking version bumps (patch and minor) that satisfy `npm audit` while keeping semver ranges sensible. > - If any **major** upgrades are required to clear vulnerabilities, handle them cautiously: > - Apply the upgrade only if tests still pass and typings/build setup remain compatible. > - If a major bump would require code changes or creates breaking behavior, **do not** apply it; instead, leave a TODO comment in the PR description summarizing which packages remain vulnerable and why. > > 3. **Validation** > - Run the existing JS-related checks that the repo supports from `/js`, such as: > - `npm test` or package-specific test scripts. > - Any documented lint/build/test commands for JS packages (e.g., `npm run build`, `npm run lint`) where applicable. > - Ensure the updated lockfiles (if present) are consistent, and the project installs cleanly with `npm ci` (or the repo's documented install command) in the `/js` area. > > 4. **Files to update** > - Update `package.json` and lockfiles under `/js` (e.g., `package-lock.json`, `npm-shrinkwrap.json`, or workspace-specific lock files) to reflect the audited dependency tree. > - Do not manually edit `node_modules`; rely on `npm` to manage dependencies and only commit manifest/lockfile changes. > > 5. **Repository conventions** > - Follow this repo's existing conventions for formatting, commit messages, and JS tooling. > - Keep the diff focused on the dependency and lockfile updates plus any absolutely necessary code tweaks to maintain compatibility. > > 6. **Pull request description** > - In the PR body, include: > - A short summary: that `npm audit fix` was run in `/js` to address dependency vulnerabilities. > - A bullet list of notable dependency changes (especially any major version bumps), with packages and old/new versions. > - A brief testing summary (commands run and their results). > - A note about any remaining vulnerabilities that could not be fixed without breaking changes (if applicable), including the affected packages and advisories if available. > > The goal is a clean, minimal PR that improves the security posture of the JS packages under `/js` in `microsoft/onnxruntime` without introducing breaking changes. --- ✨ Let Copilot coding agent [set things up for you](https://github.com/microsoft/onnxruntime/issues/new?title=✨+Set+up+Copilot+instructions&body=Configure%20instructions%20for%20this%20repository%20as%20documented%20in%20%5BBest%20practices%20for%20Copilot%20coding%20agent%20in%20your%20repository%5D%28https://gh.io/copilot-coding-agent-tips%29%2E%0A%0A%3COnboard%20this%20repo%3E&assignees=copilot) — coding agent works faster and does higher quality work when set up for your repo. --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: fs-eire <7679871+fs-eire@users.noreply.github.com> --- js/package-lock.json | 144 +++++++++++++++++------------- js/react_native/package-lock.json | 125 +++++++++++++------------- 2 files changed, 142 insertions(+), 127 deletions(-) diff --git a/js/package-lock.json b/js/package-lock.json index 1e9f5cb29fe6c..0fca515b61238 100644 --- a/js/package-lock.json +++ b/js/package-lock.json @@ -4,6 +4,7 @@ "requires": true, "packages": { "": { + "name": "js", "license": "MIT", "devDependencies": { "@eslint/compat": "^1.4.0", @@ -3230,6 +3231,27 @@ "url": "https://github.com/sponsors/ljharb" } }, + "node_modules/glob": { + "version": "10.5.0", + "resolved": "https://registry.npmjs.org/glob/-/glob-10.5.0.tgz", + "integrity": "sha512-DfXN8DfhJ7NH3Oe7cFmu3NCu1wKbkReJ8TorzSAFbSKrlNaQSKfIzqYqVY8zlbs2NLBbWpRiU52GX2PbaBVNkg==", + "dev": true, + "license": "ISC", + "dependencies": { + "foreground-child": "^3.1.0", + "jackspeak": "^3.1.2", + "minimatch": "^9.0.4", + "minipass": "^7.1.2", + "package-json-from-dist": "^1.0.0", + "path-scurry": "^1.11.1" + }, + "bin": { + "glob": "dist/esm/bin.mjs" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + } + }, "node_modules/glob-parent": { "version": "6.0.2", "resolved": "https://registry.npmjs.org/glob-parent/-/glob-parent-6.0.2.tgz", @@ -3242,6 +3264,32 @@ "node": ">=10.13.0" } }, + "node_modules/glob/node_modules/brace-expansion": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.2.tgz", + "integrity": "sha512-Jt0vHyM+jmUBqojB7E1NIYadt0vI0Qxjxd2TErW94wDz+E2LAm5vKMXXwg6ZZBTHPuUlDgQHKXvjGBdfcF1ZDQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "balanced-match": "^1.0.0" + } + }, + "node_modules/glob/node_modules/minimatch": { + "version": "9.0.5", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.5.tgz", + "integrity": "sha512-G6T0ZX48xgozx7587koeX9Ys2NYy6Gmv//P89sEte9V9whIapMNF4idKxnW2QtCcLiTWlb/wfCabAtAFWhhBow==", + "dev": true, + "license": "ISC", + "dependencies": { + "brace-expansion": "^2.0.1" + }, + "engines": { + "node": ">=16 || 14 >=14.17" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + } + }, "node_modules/global-agent": { "version": "3.0.0", "resolved": "https://registry.npmjs.org/global-agent/-/global-agent-3.0.0.tgz", @@ -4311,43 +4359,6 @@ "balanced-match": "^1.0.0" } }, - "node_modules/mocha/node_modules/glob": { - "version": "10.4.5", - "resolved": "https://registry.npmjs.org/glob/-/glob-10.4.5.tgz", - "integrity": "sha512-7Bv8RF0k6xjo7d4A/PxYLbUCfb6c+Vpd2/mB2yRDlew7Jb5hEXiCD9ibfO7wpk8i4sevK6DFny9h7EYbM3/sHg==", - "dev": true, - "license": "ISC", - "dependencies": { - "foreground-child": "^3.1.0", - "jackspeak": "^3.1.2", - "minimatch": "^9.0.4", - "minipass": "^7.1.2", - "package-json-from-dist": "^1.0.0", - "path-scurry": "^1.11.1" - }, - "bin": { - "glob": "dist/esm/bin.mjs" - }, - "funding": { - "url": "https://github.com/sponsors/isaacs" - } - }, - "node_modules/mocha/node_modules/glob/node_modules/minimatch": { - "version": "9.0.5", - "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.5.tgz", - "integrity": "sha512-G6T0ZX48xgozx7587koeX9Ys2NYy6Gmv//P89sEte9V9whIapMNF4idKxnW2QtCcLiTWlb/wfCabAtAFWhhBow==", - "dev": true, - "license": "ISC", - "dependencies": { - "brace-expansion": "^2.0.1" - }, - "engines": { - "node": ">=16 || 14 >=14.17" - }, - "funding": { - "url": "https://github.com/sponsors/isaacs" - } - }, "node_modules/mocha/node_modules/minimatch": { "version": "5.1.6", "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-5.1.6.tgz", @@ -8078,6 +8089,40 @@ "get-intrinsic": "^1.2.6" } }, + "glob": { + "version": "10.5.0", + "resolved": "https://registry.npmjs.org/glob/-/glob-10.5.0.tgz", + "integrity": "sha512-DfXN8DfhJ7NH3Oe7cFmu3NCu1wKbkReJ8TorzSAFbSKrlNaQSKfIzqYqVY8zlbs2NLBbWpRiU52GX2PbaBVNkg==", + "dev": true, + "requires": { + "foreground-child": "^3.1.0", + "jackspeak": "^3.1.2", + "minimatch": "^9.0.4", + "minipass": "^7.1.2", + "package-json-from-dist": "^1.0.0", + "path-scurry": "^1.11.1" + }, + "dependencies": { + "brace-expansion": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.2.tgz", + "integrity": "sha512-Jt0vHyM+jmUBqojB7E1NIYadt0vI0Qxjxd2TErW94wDz+E2LAm5vKMXXwg6ZZBTHPuUlDgQHKXvjGBdfcF1ZDQ==", + "dev": true, + "requires": { + "balanced-match": "^1.0.0" + } + }, + "minimatch": { + "version": "9.0.5", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.5.tgz", + "integrity": "sha512-G6T0ZX48xgozx7587koeX9Ys2NYy6Gmv//P89sEte9V9whIapMNF4idKxnW2QtCcLiTWlb/wfCabAtAFWhhBow==", + "dev": true, + "requires": { + "brace-expansion": "^2.0.1" + } + } + } + }, "glob-parent": { "version": "6.0.2", "resolved": "https://registry.npmjs.org/glob-parent/-/glob-parent-6.0.2.tgz", @@ -8772,31 +8817,6 @@ "balanced-match": "^1.0.0" } }, - "glob": { - "version": "10.4.5", - "resolved": "https://registry.npmjs.org/glob/-/glob-10.4.5.tgz", - "integrity": "sha512-7Bv8RF0k6xjo7d4A/PxYLbUCfb6c+Vpd2/mB2yRDlew7Jb5hEXiCD9ibfO7wpk8i4sevK6DFny9h7EYbM3/sHg==", - "dev": true, - "requires": { - "foreground-child": "^3.1.0", - "jackspeak": "^3.1.2", - "minimatch": "^9.0.4", - "minipass": "^7.1.2", - "package-json-from-dist": "^1.0.0", - "path-scurry": "^1.11.1" - }, - "dependencies": { - "minimatch": { - "version": "9.0.5", - "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.5.tgz", - "integrity": "sha512-G6T0ZX48xgozx7587koeX9Ys2NYy6Gmv//P89sEte9V9whIapMNF4idKxnW2QtCcLiTWlb/wfCabAtAFWhhBow==", - "dev": true, - "requires": { - "brace-expansion": "^2.0.1" - } - } - } - }, "minimatch": { "version": "5.1.6", "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-5.1.6.tgz", diff --git a/js/react_native/package-lock.json b/js/react_native/package-lock.json index e6ed2bdb9e17b..de8d631362db7 100644 --- a/js/react_native/package-lock.json +++ b/js/react_native/package-lock.json @@ -33,6 +33,7 @@ "version": "1.24.0", "license": "MIT", "devDependencies": { + "globby": "^15.0.0", "typedoc": "^0.25.7" } }, @@ -61,15 +62,15 @@ } }, "node_modules/@babel/code-frame": { - "version": "7.26.2", - "resolved": "https://registry.npmjs.org/@babel/code-frame/-/code-frame-7.26.2.tgz", - "integrity": "sha512-RJlIHRueQgwWitWgF8OdFYGZX328Ax5BCemNGlqHfplnRT9ESi8JkFlvaVYbS+UubVY6dpv87Fs2u5M29iNFVQ==", + "version": "7.27.1", + "resolved": "https://registry.npmjs.org/@babel/code-frame/-/code-frame-7.27.1.tgz", + "integrity": "sha512-cjQ7ZlQ0Mv3b47hABuTevyTuYN4i+loJKGeV9flcCgIK37cCXRh+L1bd3iBHlynerhQ7BhCkn2BPbQUL+rGqFg==", "dev": true, "license": "MIT", "dependencies": { - "@babel/helper-validator-identifier": "^7.25.9", + "@babel/helper-validator-identifier": "^7.27.1", "js-tokens": "^4.0.0", - "picocolors": "^1.0.0" + "picocolors": "^1.1.1" }, "engines": { "node": ">=6.9.0" @@ -410,9 +411,9 @@ } }, "node_modules/@babel/helper-string-parser": { - "version": "7.25.9", - "resolved": "https://registry.npmjs.org/@babel/helper-string-parser/-/helper-string-parser-7.25.9.tgz", - "integrity": "sha512-4A/SCr/2KLd5jrtOMFzaKjVtAei3+2r/NChoBNoZ3EyP/+GlhoaEGoWOZUmFmoITP7zOJyHIMm+DYRd8o3PvHA==", + "version": "7.27.1", + "resolved": "https://registry.npmjs.org/@babel/helper-string-parser/-/helper-string-parser-7.27.1.tgz", + "integrity": "sha512-qMlSxKbpRlAridDExk92nSobyDdpPijUq2DW6oDnUqd0iOGxmQjyqhMIihI9+zv4LPyZdRje2cavWPbCbWm3eA==", "dev": true, "license": "MIT", "engines": { @@ -420,9 +421,9 @@ } }, "node_modules/@babel/helper-validator-identifier": { - "version": "7.25.9", - "resolved": "https://registry.npmjs.org/@babel/helper-validator-identifier/-/helper-validator-identifier-7.25.9.tgz", - "integrity": "sha512-Ed61U6XJc3CVRfkERJWDz4dJwKe7iLmmJsbOGu9wSloNSFttHV0I8g6UAgb7qnK5ly5bGLPd4oXZlxCdANBOWQ==", + "version": "7.28.5", + "resolved": "https://registry.npmjs.org/@babel/helper-validator-identifier/-/helper-validator-identifier-7.28.5.tgz", + "integrity": "sha512-qSs4ifwzKJSV39ucNjsvc6WVHs6b7S03sOh2OcHF9UHfVPqWWALUsNUVzhSBiItjRZoLHx7nIarVjqKVusUZ1Q==", "dev": true, "license": "MIT", "engines": { @@ -455,27 +456,27 @@ } }, "node_modules/@babel/helpers": { - "version": "7.25.6", - "resolved": "https://registry.npmjs.org/@babel/helpers/-/helpers-7.25.6.tgz", - "integrity": "sha512-Xg0tn4HcfTijTwfDwYlvVCl43V6h4KyVVX2aEm4qdO/PC6L2YvzLHFdmxhoeSA3eslcE6+ZVXHgWwopXYLNq4Q==", + "version": "7.28.4", + "resolved": "https://registry.npmjs.org/@babel/helpers/-/helpers-7.28.4.tgz", + "integrity": "sha512-HFN59MmQXGHVyYadKLVumYsA9dBFun/ldYxipEjzA4196jpLZd8UjEEBLkbEkvfYreDqJhZxYAWFPtrfhNpj4w==", "dev": true, "license": "MIT", "dependencies": { - "@babel/template": "^7.25.0", - "@babel/types": "^7.25.6" + "@babel/template": "^7.27.2", + "@babel/types": "^7.28.4" }, "engines": { "node": ">=6.9.0" } }, "node_modules/@babel/parser": { - "version": "7.26.9", - "resolved": "https://registry.npmjs.org/@babel/parser/-/parser-7.26.9.tgz", - "integrity": "sha512-81NWa1njQblgZbQHxWHpxxCzNsa3ZwvFqpUg7P+NNUU6f3UU2jBEg4OlF/J6rl8+PQGh1q6/zWScd001YwcA5A==", + "version": "7.28.5", + "resolved": "https://registry.npmjs.org/@babel/parser/-/parser-7.28.5.tgz", + "integrity": "sha512-KKBU1VGYR7ORr3At5HAtUQ+TV3SzRCXmA/8OdDZiLDBIZxVyzXuztPjfLd3BV1PRAQGCMWWSHYhL0F8d5uHBDQ==", "dev": true, "license": "MIT", "dependencies": { - "@babel/types": "^7.26.9" + "@babel/types": "^7.28.5" }, "bin": { "parser": "bin/babel-parser.js" @@ -2114,35 +2115,25 @@ } }, "node_modules/@babel/runtime": { - "version": "7.25.6", - "resolved": "https://registry.npmjs.org/@babel/runtime/-/runtime-7.25.6.tgz", - "integrity": "sha512-VBj9MYyDb9tuLq7yzqjgzt6Q+IBQLrGZfdjOekyEirZPHxXWoTSGUTMrpsfi58Up73d13NfYLv8HT9vmznjzhQ==", + "version": "7.28.4", + "resolved": "https://registry.npmjs.org/@babel/runtime/-/runtime-7.28.4.tgz", + "integrity": "sha512-Q/N6JNWvIvPnLDvjlE1OUBLPQHH6l3CltCEsHIujp45zQUSSh8K+gHnaEX45yAT1nyngnINhvWtzN+Nb9D8RAQ==", "dev": true, "license": "MIT", - "dependencies": { - "regenerator-runtime": "^0.14.0" - }, "engines": { "node": ">=6.9.0" } }, - "node_modules/@babel/runtime/node_modules/regenerator-runtime": { - "version": "0.14.1", - "resolved": "https://registry.npmjs.org/regenerator-runtime/-/regenerator-runtime-0.14.1.tgz", - "integrity": "sha512-dYnhHh0nJoMfnkZs6GmmhFknAGRrLznOu5nc9ML+EJxGvrx6H7teuevqVqCuPcPK//3eDrrjQhehXVx9cnkGdw==", - "dev": true, - "license": "MIT" - }, "node_modules/@babel/template": { - "version": "7.26.9", - "resolved": "https://registry.npmjs.org/@babel/template/-/template-7.26.9.tgz", - "integrity": "sha512-qyRplbeIpNZhmzOysF/wFMuP9sctmh2cFzRAZOn1YapxBsE1i9bJIY586R/WBLfLcmcBlM8ROBiQURnnNy+zfA==", + "version": "7.27.2", + "resolved": "https://registry.npmjs.org/@babel/template/-/template-7.27.2.tgz", + "integrity": "sha512-LPDZ85aEJyYSd18/DkjNh4/y1ntkE5KwUHWTiqgRxruuZL2F1yuHligVHLvcHY2vMHXttKFpJn6LwfI7cw7ODw==", "dev": true, "license": "MIT", "dependencies": { - "@babel/code-frame": "^7.26.2", - "@babel/parser": "^7.26.9", - "@babel/types": "^7.26.9" + "@babel/code-frame": "^7.27.1", + "@babel/parser": "^7.27.2", + "@babel/types": "^7.27.1" }, "engines": { "node": ">=6.9.0" @@ -2189,14 +2180,14 @@ "license": "MIT" }, "node_modules/@babel/types": { - "version": "7.26.9", - "resolved": "https://registry.npmjs.org/@babel/types/-/types-7.26.9.tgz", - "integrity": "sha512-Y3IR1cRnOxOCDvMmNiym7XpXQ93iGDDPHx+Zj+NM+rg0fBaShfQLkg+hKPaZCEvg5N/LeCo4+Rj/i3FuJsIQaw==", + "version": "7.28.5", + "resolved": "https://registry.npmjs.org/@babel/types/-/types-7.28.5.tgz", + "integrity": "sha512-qQ5m48eI/MFLQ5PxQj4PFaprjyCTLI37ElWMmNs0K8Lk3dVeOdNpB3ks8jc7yM5CDmVC73eMVk/trk3fgmrUpA==", "dev": true, "license": "MIT", "dependencies": { - "@babel/helper-string-parser": "^7.25.9", - "@babel/helper-validator-identifier": "^7.25.9" + "@babel/helper-string-parser": "^7.27.1", + "@babel/helper-validator-identifier": "^7.28.5" }, "engines": { "node": ">=6.9.0" @@ -3319,9 +3310,9 @@ } }, "node_modules/babel-plugin-module-resolver/node_modules/brace-expansion": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.1.tgz", - "integrity": "sha512-XnAIvQ8eM+kC6aULx6wuQiwVsnzsi9d3WxzV3FpWTGA19F621kwdbsAcFKXgKUHZWsy+mY6iL1sHTxWEFCytDA==", + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.2.tgz", + "integrity": "sha512-Jt0vHyM+jmUBqojB7E1NIYadt0vI0Qxjxd2TErW94wDz+E2LAm5vKMXXwg6ZZBTHPuUlDgQHKXvjGBdfcF1ZDQ==", "dev": true, "license": "MIT", "dependencies": { @@ -3477,7 +3468,9 @@ } }, "node_modules/brace-expansion": { - "version": "1.1.11", + "version": "1.1.12", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.12.tgz", + "integrity": "sha512-9T9UjW3r0UW5c1Q7GTwllptXwhvYmEzFhzMfZ9H7FQWt+uZePjZPjBP/W1ZEyZ1twGWom5/56TF4lPcqjnDHcg==", "dev": true, "license": "MIT", "dependencies": { @@ -3831,9 +3824,9 @@ } }, "node_modules/compression": { - "version": "1.8.0", - "resolved": "https://registry.npmjs.org/compression/-/compression-1.8.0.tgz", - "integrity": "sha512-k6WLKfunuqCYD3t6AsuPGvQWaKwuLLh2/xHNcX4qE+vIfDNXpSqnrhwA7O53R7WVQUnt8dVAIW+YHr7xTgOgGA==", + "version": "1.8.1", + "resolved": "https://registry.npmjs.org/compression/-/compression-1.8.1.tgz", + "integrity": "sha512-9mAqGPHLakhCLeNyxPkK4xVo746zQ/czLH1Ky+vkitMnWfWZps8r0qXuwhwizagCRttsL4lfG4pIOvaWLpAP0w==", "dev": true, "license": "MIT", "dependencies": { @@ -3841,7 +3834,7 @@ "compressible": "~2.0.18", "debug": "2.6.9", "negotiator": "~0.6.4", - "on-headers": "~1.0.2", + "on-headers": "~1.1.0", "safe-buffer": "5.2.1", "vary": "~1.1.2" }, @@ -4821,9 +4814,9 @@ } }, "node_modules/image-size": { - "version": "1.2.0", - "resolved": "https://registry.npmjs.org/image-size/-/image-size-1.2.0.tgz", - "integrity": "sha512-4S8fwbO6w3GeCVN6OPtA9I5IGKkcDMPcKndtUlpJuCwu7JLjtj7JZpwqLuyY2nrmQT3AWsCJLSKPsc2mPBSl3w==", + "version": "1.2.1", + "resolved": "https://registry.npmjs.org/image-size/-/image-size-1.2.1.tgz", + "integrity": "sha512-rH+46sQJ2dlwfjfhCyNx5thzrv+dtmBIhPHk0zgRUukHzZ/kRueTJXoYYsclBaKcSMBWuGbOFXtioLpzTb5euw==", "dev": true, "license": "MIT", "dependencies": { @@ -5250,7 +5243,9 @@ "license": "MIT" }, "node_modules/js-yaml": { - "version": "3.14.1", + "version": "3.14.2", + "resolved": "https://registry.npmjs.org/js-yaml/-/js-yaml-3.14.2.tgz", + "integrity": "sha512-PMSmkqxr106Xa156c2M265Z+FTrPl+oxd/rgOQy2tijQeK5TxQ43psO1ZCwhVOSdnn+RzkzlRz/eY4BgJBYVpg==", "dev": true, "license": "MIT", "dependencies": { @@ -6544,9 +6539,9 @@ } }, "node_modules/on-headers": { - "version": "1.0.2", - "resolved": "https://registry.npmjs.org/on-headers/-/on-headers-1.0.2.tgz", - "integrity": "sha512-pZAE+FJLoyITytdqK0U5s+FIpjN0JP3OzFi/u8Rx+EV5/W+JTWGXG8xFzevE7AjBfDqHv/8vL8qQsIhHnqRkrA==", + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/on-headers/-/on-headers-1.1.0.tgz", + "integrity": "sha512-737ZY3yNnXy37FHkQxPzt4UZ2UWPWiCZWLvFZ4fu5cueciegX0zGPnrlY6bwRg4FdQOe9YU8MkmJwGhoMybl8A==", "dev": true, "license": "MIT", "engines": { @@ -7130,9 +7125,9 @@ "license": "Python-2.0" }, "node_modules/react-native-builder-bob/node_modules/brace-expansion": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.1.tgz", - "integrity": "sha512-XnAIvQ8eM+kC6aULx6wuQiwVsnzsi9d3WxzV3FpWTGA19F621kwdbsAcFKXgKUHZWsy+mY6iL1sHTxWEFCytDA==", + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.2.tgz", + "integrity": "sha512-Jt0vHyM+jmUBqojB7E1NIYadt0vI0Qxjxd2TErW94wDz+E2LAm5vKMXXwg6ZZBTHPuUlDgQHKXvjGBdfcF1ZDQ==", "dev": true, "license": "MIT", "dependencies": { @@ -7203,9 +7198,9 @@ } }, "node_modules/react-native-builder-bob/node_modules/js-yaml": { - "version": "4.1.0", - "resolved": "https://registry.npmjs.org/js-yaml/-/js-yaml-4.1.0.tgz", - "integrity": "sha512-wpxZs9NoxZaJESJGIZTyDEaYpl0FKSA+FB9aJiyemKhMwkxQg63h4T1KJgUGHpTqPDNRcmmYLugrRjJlBtWvRA==", + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/js-yaml/-/js-yaml-4.1.1.tgz", + "integrity": "sha512-qQKT4zQxXl8lLwBtHMWwaTcGfFOZviOJet3Oy/xmGk2gZH677CJM9EvtfdSkgWcATZhj/55JZ0rmy3myCT5lsA==", "dev": true, "license": "MIT", "dependencies": { From df8bf2dfb686de23b2712c073f393eb07834a0f0 Mon Sep 17 00:00:00 2001 From: Wenqin Yang Date: Wed, 26 Nov 2025 06:04:36 +0800 Subject: [PATCH 02/20] [webgpu] Optimize InstanceNormalization by removing redundant transpose (#26626) ### Description This PR optimizes `InstanceNormalization` by removing redundant transpose. Given the implementation of `InstanceNormalization` for `NCHW` is more effiencient, we don't need to add wrapper `Transpose` to make it run in `NHWC`, which helps use to elide redundant transpose and improve performance. Testing on Lunar Lake shows about `~60%` performance improvement in `InstanceNormalization` operations. #### `InstanceNormalization` OP benchmark The input tensor shape: `(1,32,1048576)` The scale tensor shape: `(32)` The B tensor shape: `(32)` | time cost (ms) | baseline | opt | diff | | ---------------- | -------- | ---- | ---- | | Lunar Lake | 82.6 | 34.2 | 58% | #### Model benchmark | time cost (ms) | baseline | opt | diff | | ---------------- | -------- | ---- | ---- | | sd-turbo-vae-decoder-fp16-demo | 2437.6 | 1835.9 | 25% | ### Motivation and Context Please see above --- .../providers/webgpu/webgpu_execution_provider.cc | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index e0b84fef51f1f..395517e068452 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -935,13 +935,14 @@ std::unique_ptr WebGpuExecutionProvider::GetEx std::optional WebGpuExecutionProvider::ShouldConvertDataLayoutForOp(std::string_view node_domain, std::string_view node_op_type, DataLayout target_data_layout) const { - if (target_data_layout != DataLayout::NHWC) { - return std::nullopt; - } - // NHWC for Resize operator is not implemented on kWebGpuExecutionProvider if (node_domain == kOnnxDomain && node_op_type == "Resize") { - return false; + return target_data_layout != DataLayout::NHWC; + } + + // WebGPU perfer NCHW for InstanceNormalization due to a better performance + if (node_domain == kOnnxDomain && node_op_type == "InstanceNormalization") { + return target_data_layout != DataLayout::NHWC; } return std::nullopt; From 5c28c7e41bd52aaabe8b3ec6e50cdbcf0f230894 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Tue, 25 Nov 2025 19:18:12 -0800 Subject: [PATCH 03/20] [webgpu] refactor a few "context" classes (#26602) ### Description This PR refactors a few "context" classes to make it clearer and support new features. --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com> --- .../core/providers/webgpu/allocator.cc | 2 +- onnxruntime/core/providers/webgpu/allocator.h | 5 + .../core/providers/webgpu/compute_context.cc | 23 ++-- .../core/providers/webgpu/compute_context.h | 103 ++++++++++++------ onnxruntime/core/providers/webgpu/nn/conv.cc | 40 +++++++ onnxruntime/core/providers/webgpu/nn/conv.h | 7 ++ .../core/providers/webgpu/tensor/transpose.cc | 2 +- .../core/providers/webgpu/tensor/transpose.h | 2 +- .../core/providers/webgpu/webgpu_context.cc | 18 ++- .../core/providers/webgpu/webgpu_context.h | 21 ++-- .../webgpu/webgpu_execution_provider.cc | 3 +- .../core/providers/webgpu/webgpu_kernel.cc | 47 ++++++-- .../core/providers/webgpu/webgpu_kernel.h | 33 ++++++ .../core/providers/webgpu/webgpu_utils.cc | 15 +-- .../core/providers/webgpu/webgpu_utils.h | 5 +- 15 files changed, 240 insertions(+), 86 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/allocator.cc b/onnxruntime/core/providers/webgpu/allocator.cc index b3eb4b5061423..3e1b87821fe2f 100644 --- a/onnxruntime/core/providers/webgpu/allocator.cc +++ b/onnxruntime/core/providers/webgpu/allocator.cc @@ -13,7 +13,7 @@ GpuBufferAllocator::GpuBufferAllocator(const BufferManager& buffer_manager, bool OrtMemoryInfo(WEBGPU_BUFFER, is_read_only_allocator ? OrtAllocatorType::OrtReadOnlyAllocator : OrtAllocatorType::OrtDeviceAllocator, - OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NONE, 0), + WebGpuDevice, OrtMemTypeDefault)), buffer_manager_{buffer_manager}, mapped_at_creation_{is_read_only_allocator && buffer_manager.SupportsUMA()} { diff --git a/onnxruntime/core/providers/webgpu/allocator.h b/onnxruntime/core/providers/webgpu/allocator.h index 7c38b4557e078..74b3d669fcf3b 100644 --- a/onnxruntime/core/providers/webgpu/allocator.h +++ b/onnxruntime/core/providers/webgpu/allocator.h @@ -11,6 +11,11 @@ namespace webgpu { class BufferManager; +inline constexpr OrtDevice WebGpuDevice{OrtDevice::GPU, + OrtDevice::MemType::DEFAULT, + OrtDevice::VendorIds::NONE, + 0}; + class GpuBufferAllocator : public IAllocator { public: GpuBufferAllocator(const BufferManager& buffer_manager, bool is_read_only_allocator); diff --git a/onnxruntime/core/providers/webgpu/compute_context.cc b/onnxruntime/core/providers/webgpu/compute_context.cc index ebe71c6ccfacd..d1a2011c8e191 100644 --- a/onnxruntime/core/providers/webgpu/compute_context.cc +++ b/onnxruntime/core/providers/webgpu/compute_context.cc @@ -6,22 +6,25 @@ namespace onnxruntime { namespace webgpu { -ComputeContext::ComputeContext(OpKernelContext& kernel_context, - const OpKernel& op_kernel, - const WebGpuExecutionProvider& ep, - WebGpuContext& webgpu_context) + +ComputeContextBase::ComputeContextBase(WebGpuContext& webgpu_context, + const WebGpuExecutionProvider& ep, + const OpKernel& op_kernel) : webgpu_context_{webgpu_context}, - kernel_context_{kernel_context}, - op_kernel_{op_kernel}, - ep_{ep} { + ep_{ep}, + op_kernel_{op_kernel} { } -const webgpu::BufferManager& ComputeContext::BufferManagerAccessor::Get(const ComputeContext& context) { +const webgpu::BufferManager& ComputeContextBase::BufferManagerAccessor::Get(const ComputeContextBase& context) { return context.ep_.BufferManager(); } -const SplitKConfig& ComputeContext::GetSplitKConfig() { - return webgpu_context_.GetSplitKConfig(); +ComputeContext::ComputeContext(WebGpuContext& webgpu_context, + const WebGpuExecutionProvider& ep, + const OpKernel& op_kernel, + OpKernelContext& kernel_context) + : ComputeContextBase(webgpu_context, ep, op_kernel), + kernel_context_{kernel_context} { } } // namespace webgpu diff --git a/onnxruntime/core/providers/webgpu/compute_context.h b/onnxruntime/core/providers/webgpu/compute_context.h index ed16f2f0a1345..fdf89854469d6 100644 --- a/onnxruntime/core/providers/webgpu/compute_context.h +++ b/onnxruntime/core/providers/webgpu/compute_context.h @@ -24,7 +24,13 @@ namespace webgpu { class WebGpuContext; class BufferManager; -class ComputeContext final { +// +// Class ComputeContextBase is designed to provide basic context information +// for running a compute shader program. +// +// An instance of ComputeContextBase does not depend on OpKernelContext, which needs an execution frame to be created. +// +class ComputeContextBase { public: // Nested accessor class to provide controlled access to BufferManager class BufferManagerAccessor { @@ -34,18 +40,31 @@ class ComputeContext final { friend class WebGpuContext; private: - static const webgpu::BufferManager& Get(const ComputeContext& context); + static const webgpu::BufferManager& Get(const ComputeContextBase& context); }; - ComputeContext(OpKernelContext& kernel_context, - const OpKernel& op_kernel, - const WebGpuExecutionProvider& ep, - WebGpuContext& webgpu_context); + ComputeContextBase(WebGpuContext& webgpu_context, + const WebGpuExecutionProvider& ep, + const OpKernel& op_kernel); - ~ComputeContext() = default; + ~ComputeContextBase() = default; + + // + // Get the node name. + // + inline decltype(auto) NodeName() const { + return op_kernel_.Node().Name(); + } + + // + // Get the operator type. + // + inline decltype(auto) OpType() const { + return op_kernel_.Node().OpType(); + } // - // Get various information from the context. + // Get various information from the WebGPU context. // inline const wgpu::AdapterInfo& AdapterInfo() const { @@ -57,9 +76,6 @@ class ComputeContext final { inline bool HasFeature(wgpu::FeatureName feature) const { return webgpu_context_.DeviceHasFeature(feature); } - inline bool IsGraphCaptureEnabled() const { - return ep_.IsGraphCaptureEnabled(); - } #if !defined(__wasm__) inline const wgpu::AdapterPropertiesSubgroupMatrixConfigs& SubgroupMatrixConfigs() const { return webgpu_context_.SubgroupMatrixConfigs(); @@ -67,17 +83,57 @@ class ComputeContext final { #endif // - // Get the kernel context. + // Get Split-K configuration. // - inline OpKernelContext& KernelContext() { - return kernel_context_; + inline const SplitKConfig& GetSplitKConfig() const { + return webgpu_context_.GetSplitKConfig(); + } + + // + // Get whether graph capture is enabled. + // + inline bool IsGraphCaptureEnabled() const { + return ep_.IsGraphCaptureEnabled(); } // // Get the logger. // inline const logging::Logger& Logger() const { - return kernel_context_.Logger(); + return *ep_.GetLogger(); + } + + // + // Run a compute shader program. + // + inline Status RunProgram(const ProgramBase& program) { + return webgpu_context_.Run(*this, program); + } + + protected: + WebGpuContext& webgpu_context_; + const WebGpuExecutionProvider& ep_; + const OpKernel& op_kernel_; +}; + +// +// Class ComputeContext provides all information a `ComputeContextBase` provides, and also +// access to `OpKernelContext` for input and output tensors. +// +class ComputeContext final : public ComputeContextBase { + public: + ComputeContext(WebGpuContext& webgpu_context, + const WebGpuExecutionProvider& ep, + const OpKernel& op_kernel, + OpKernelContext& kernel_context); + + ~ComputeContext() = default; + + // + // Get the kernel context. + // + inline OpKernelContext& KernelContext() { + return kernel_context_; } // @@ -145,25 +201,8 @@ class ComputeContext final { return op_kernel_.Info().GetDataTransferManager().CopyTensor(src, dst); } - // - // Run a compute shader program. - // - inline Status RunProgram(const ProgramBase& program) { - return webgpu_context_.Run(*this, program); - } - - // - // Get Split-K configuration. - // - // `split_k_config_` won't be initialized until the first call to this method. - // - const SplitKConfig& GetSplitKConfig(); - private: - WebGpuContext& webgpu_context_; OpKernelContext& kernel_context_; - const OpKernel& op_kernel_; - const WebGpuExecutionProvider& ep_; }; } // namespace webgpu diff --git a/onnxruntime/core/providers/webgpu/nn/conv.cc b/onnxruntime/core/providers/webgpu/nn/conv.cc index 77fa46cb87518..4fff736fd2f32 100644 --- a/onnxruntime/core/providers/webgpu/nn/conv.cc +++ b/onnxruntime/core/providers/webgpu/nn/conv.cc @@ -216,6 +216,46 @@ Status Conv::ComputeInternal(ComputeContext& context return context.RunProgram(conv2d_mm_program); } +template +Status Conv::PrePackInternal(ComputeContextBase& /* context */, + const Tensor& tensor, + int input_idx, + AllocatorPtr /* alloc */, + /*out*/ bool& is_packed) { + is_packed = false; + + if constexpr (is_channels_last) { + if (input_idx == 1 && tensor.Shape().NumDimensions() == 4) { + // only deal with 4D NHWC weights + + // TODO: implement weight transpose for pre-pack here + // Conv::ComputeInternal() should be updated to reflect the change: + // - if the initializer is packed, `context.Input(1)` will be nullptr. + // - in this case, use `transposed_kernel_` instead. + + // // Step.1 - calculate transposed weight shape + // TensorShape transposed_kernel_shape{tensor.Shape()[2], + // tensor.Shape()[3], + // tensor.Shape()[1], + // tensor.Shape()[0]}; + + // // Step.2 - create transposed weight tensor + // transposed_kernel_ = std::make_unique(tensor.DataType(), transposed_kernel_shape, alloc); + + // // Step.3 - do transpose + // size_t perm[] = {2, 3, 1, 0}; + // ORT_RETURN_IF_ERROR(Transpose::DoTranspose(context, + // perm, + // tensor, + // *transposed_kernel_)); + + // is_packed = true; // set this flag to true so that ORT will release the initializer tensor + } + } + + return Status::OK(); +} + // Explicit template instantiation for FusedConv template class Conv; template class Conv; diff --git a/onnxruntime/core/providers/webgpu/nn/conv.h b/onnxruntime/core/providers/webgpu/nn/conv.h index cafaa272c0613..5bf94a459a44a 100644 --- a/onnxruntime/core/providers/webgpu/nn/conv.h +++ b/onnxruntime/core/providers/webgpu/nn/conv.h @@ -23,9 +23,16 @@ class Conv : public WebGpuKernel { } Status ComputeInternal(ComputeContext& context) const override; + Status PrePackInternal(ComputeContextBase& context, + const Tensor& tensor, + int input_idx, + AllocatorPtr alloc, + /*out*/ bool& is_packed) override; + protected: ConvAttributes conv_attrs_; Activation activation_; + std::unique_ptr transposed_kernel_; // should only have value when `is_initializer` AND `is_4D` AND `is_NHWC` }; Status TransposeKernel(ComputeContext& context, const Tensor* kernel, const TensorShape& kernel_shape, Tensor* transposed_kernel, const InlinedVector& perm); diff --git a/onnxruntime/core/providers/webgpu/tensor/transpose.cc b/onnxruntime/core/providers/webgpu/tensor/transpose.cc index cec321d0da80e..5415d4a5ead5b 100644 --- a/onnxruntime/core/providers/webgpu/tensor/transpose.cc +++ b/onnxruntime/core/providers/webgpu/tensor/transpose.cc @@ -108,7 +108,7 @@ Status TransposeProgram::GenerateShaderCode(ShaderHelper& shader) const { return Status::OK(); } -Status Transpose::DoTranspose(onnxruntime::webgpu::ComputeContext& context, +Status Transpose::DoTranspose(onnxruntime::webgpu::ComputeContextBase& context, gsl::span permutations, const Tensor& input, Tensor& output) { const auto& input_shape = input.Shape(); diff --git a/onnxruntime/core/providers/webgpu/tensor/transpose.h b/onnxruntime/core/providers/webgpu/tensor/transpose.h index b62a419fa12bc..5e9ccc6750cd6 100644 --- a/onnxruntime/core/providers/webgpu/tensor/transpose.h +++ b/onnxruntime/core/providers/webgpu/tensor/transpose.h @@ -16,7 +16,7 @@ class Transpose final : public WebGpuKernel, public TransposeBase { Transpose(const OpKernelInfo& info) : WebGpuKernel{info}, TransposeBase{info} { } Status ComputeInternal(ComputeContext& context) const override; - static Status DoTranspose(onnxruntime::webgpu::ComputeContext& context, gsl::span permutations, const Tensor& input, Tensor& output); + static Status DoTranspose(onnxruntime::webgpu::ComputeContextBase& context, gsl::span permutations, const Tensor& input, Tensor& output); constexpr static uint32_t TILE_SIZE = 16; }; diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc index 28decb076951e..b8d5adc421124 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -147,6 +147,9 @@ void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_confi // create program manager program_mgr_ = std::make_unique(*this); + // create split-k config + split_k_config_ = std::make_unique(adapter_info_); + // set query type #if !defined(__wasm__) if (DeviceHasFeature(wgpu::FeatureName::ChromiumExperimentalTimestampQueryInsidePasses)) { @@ -178,7 +181,7 @@ Status WebGpuContext::Wait(wgpu::Future f) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to wait for the operation:", uint32_t(status)); } -Status WebGpuContext::Run(ComputeContext& context, const ProgramBase& program) { +Status WebGpuContext::Run(ComputeContextBase& context, const ProgramBase& program) { const auto& inputs = program.Inputs(); const auto& outputs = program.Outputs(); @@ -288,8 +291,8 @@ Status WebGpuContext::Run(ComputeContext& context, const ProgramBase& program) { auto key = CalculateProgramCacheKey(program, inputs_segments, outputs_segments, is_1d_dispatch); if (is_profiling_) { - PendingKernelInfo pending_kernel_info(context.KernelContext().GetNodeName(), - context.KernelContext().GetOpType(), + PendingKernelInfo pending_kernel_info(context.NodeName(), + context.OpType(), program.Name(), key, inputs, @@ -442,7 +445,7 @@ Status WebGpuContext::Run(ComputeContext& context, const ProgramBase& program) { const size_t uniform_buffer_total_size = (current_offset + max_alignment_of_field - 1) / max_alignment_of_field * max_alignment_of_field; WGPUBuffer uniform_buffer = nullptr; - const webgpu::BufferManager& buffer_mgr = ComputeContext::BufferManagerAccessor::Get(context); + const webgpu::BufferManager& buffer_mgr = ComputeContextBase::BufferManagerAccessor::Get(context); if (uniform_buffer_total_size > 0) { std::vector uniform_data_buffer(uniform_buffer_total_size); @@ -910,13 +913,6 @@ void WebGpuContext::ReleaseGraphResources(std::vector WebGpuContextFactory::contexts_; std::mutex WebGpuContextFactory::mutex_; std::once_flag WebGpuContextFactory::init_default_flag_; diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.h b/onnxruntime/core/providers/webgpu/webgpu_context.h index bd7dae75f2e2d..84dfb47ef4687 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.h +++ b/onnxruntime/core/providers/webgpu/webgpu_context.h @@ -5,7 +5,6 @@ #include #include -#include #include "core/providers/webgpu/webgpu_external_header.h" @@ -23,7 +22,7 @@ class Tensor; namespace webgpu { class WebGpuContext; -class ComputeContext; +class ComputeContextBase; class ProgramBase; // Definition for CapturedCommandInfo in the webgpu namespace @@ -152,6 +151,13 @@ class WebGpuContext final { return validation_mode_; } + // + // Get Split-K configuration. + // + const SplitKConfig& GetSplitKConfig() const { + return *split_k_config_; + } + void StartProfiling(); void CollectProfilingData(profiling::Events& events); void EndProfiling(TimePoint, profiling::Events& events, profiling::Events& cached_events); @@ -170,16 +176,9 @@ class WebGpuContext final { // Status PopErrorScope(); - Status Run(ComputeContext& context, const ProgramBase& program); + Status Run(ComputeContextBase& context, const ProgramBase& program); void OnRunEnd(); - // - // Get Split-K configuration. - // - // `split_k_config_` won't be initialized until the first call to this method. - // - const SplitKConfig& GetSplitKConfig(); - private: enum class TimestampQueryType { None = 0, @@ -277,7 +276,7 @@ class WebGpuContext final { uint32_t num_pending_dispatches_ = 0; const uint32_t max_num_pending_dispatches_ = 16; - std::optional split_k_config_; + std::unique_ptr split_k_config_; // profiling TimestampQueryType query_type_; diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index 395517e068452..6b764d51bcf75 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -794,8 +794,7 @@ using namespace webgpu; WebGpuExecutionProvider::WebGpuExecutionProvider(int context_id, WebGpuContext& context, WebGpuExecutionProviderConfig&& config) - : IExecutionProvider{kWebGpuExecutionProvider, - OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NONE, 0)}, + : IExecutionProvider{kWebGpuExecutionProvider, WebGpuDevice}, context_id_{context_id}, context_{context}, preferred_data_layout_{config.data_layout}, diff --git a/onnxruntime/core/providers/webgpu/webgpu_kernel.cc b/onnxruntime/core/providers/webgpu/webgpu_kernel.cc index 8d6ae6caeaf83..ea38e9415e1fe 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_kernel.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_kernel.cc @@ -11,25 +11,58 @@ namespace webgpu { WebGpuKernel::WebGpuKernel(const OpKernelInfo& info) : OpKernel(info), - ep_(*static_cast(info.GetExecutionProvider())) { + ep_(*static_cast(info.GetExecutionProvider())), + webgpu_context_(WebGpuContextFactory::GetContext(ep_.GetDeviceId())) { } Status WebGpuKernel::Compute(OpKernelContext* p_op_kernel_context) const { - WebGpuContext& webgpu_context = WebGpuContextFactory::GetContext(ep_.GetDeviceId()); - ComputeContext context{*p_op_kernel_context, *this, ep_, webgpu_context}; + ComputeContext context{webgpu_context_, + ep_, + *this, + *p_op_kernel_context}; - if (webgpu_context.ValidationMode() >= ValidationMode::Full) { - webgpu_context.PushErrorScope(); + if (webgpu_context_.ValidationMode() >= ValidationMode::Full) { + webgpu_context_.PushErrorScope(); } Status s = ComputeInternal(context); - if (webgpu_context.ValidationMode() >= ValidationMode::Full) { - ORT_RETURN_IF_ERROR(webgpu_context.PopErrorScope()); + if (webgpu_context_.ValidationMode() >= ValidationMode::Full) { + ORT_RETURN_IF_ERROR(webgpu_context_.PopErrorScope()); } return s; } +Status WebGpuKernel::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + /*out*/ bool& is_packed, /*out*/ PrePackedWeights* /* prepacked_weights */) { + ComputeContextBase context{webgpu_context_, ep_, *this}; + + if (webgpu_context_.ValidationMode() >= ValidationMode::Full) { + webgpu_context_.PushErrorScope(); + } + + // Currently, ORT does not allow using prepacked weights in non-CPU EPs. + // So we do not pass prepacked_weights to PrePackInternal. + // Kernel implementation that supports prepacking should manage its own storage. + + Status s = PrePackInternal(context, tensor, input_idx, alloc, is_packed); + + if (webgpu_context_.ValidationMode() >= ValidationMode::Full) { + ORT_RETURN_IF_ERROR(webgpu_context_.PopErrorScope()); + } + + return s; +} + +Status WebGpuKernel::PrePackInternal(ComputeContextBase& /*context*/, + const Tensor& /*tensor*/, + int /*input_idx*/, + AllocatorPtr /*alloc*/, + /*out*/ bool& is_packed) { + is_packed = false; + return Status::OK(); +} + } // namespace webgpu } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_kernel.h b/onnxruntime/core/providers/webgpu/webgpu_kernel.h index 3c750e305421c..2c57991c6ee35 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_kernel.h +++ b/onnxruntime/core/providers/webgpu/webgpu_kernel.h @@ -23,8 +23,41 @@ class WebGpuKernel : public OpKernel { virtual Status ComputeInternal(ComputeContext& context) const = 0; + // Overrides OpKernel::PrePack to handle constant tensor pre-processing for WebGPU kernels. + // This method creates a ComputeContextBase and delegates to PrePackInternal. + // + // NOTE: Currently, ORT does not allow using prepacked weights in non-CPU EPs, so the + // prepacked_weights parameter is not passed to PrePackInternal. Kernel implementations + // that support prepacking should manage their own storage. + Status PrePack(const Tensor& tensor, + int input_idx, + AllocatorPtr alloc, + /*out*/ bool& is_packed, + /*out*/ PrePackedWeights* prepacked_weights) override; + + // Virtual method that allows derived kernels to pre-process constant tensors during initialization. + // + // This method is called during kernel initialization when constant tensors are available, + // allowing kernels to perform operations like tensor transposition or format conversion + // before the first Compute call. + // + // @param context The WebGPU compute context base providing access to the execution environment. + // @param tensor The constant tensor to potentially pre-process. + // @param input_idx The index of this input in the kernel's input list. + // @param alloc The allocator to use for any new tensor allocations. + // @param is_packed Output parameter. Set to true if the tensor was pre-packed/processed, + // false otherwise. The default implementation sets this to false. + // + // @return Status::OK() on success, or an error status on failure. + virtual Status PrePackInternal(ComputeContextBase& context, + const Tensor& tensor, + int input_idx, + AllocatorPtr alloc, + /*out*/ bool& is_packed); + private: const WebGpuExecutionProvider& ep_; + WebGpuContext& webgpu_context_; }; } // namespace webgpu diff --git a/onnxruntime/core/providers/webgpu/webgpu_utils.cc b/onnxruntime/core/providers/webgpu/webgpu_utils.cc index 568d29a96cb88..5fd24b2bff037 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_utils.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_utils.cc @@ -21,27 +21,24 @@ TensorShape ReduceShapeByComponents(const TensorShape& shape, int64_t components return TensorShape(shape_vector); } -SplitKConfig SplitKConfig::GetSplitKConfig(const wgpu::AdapterInfo& adapter_info) { - SplitKConfig config = {}; - +SplitKConfig::SplitKConfig(const wgpu::AdapterInfo& adapter_info) { if (adapter_info.vendor == std::string_view{"intel"}) { if (adapter_info.architecture == std::string_view{"xe-2lpg"} || adapter_info.architecture == std::string_view{"xe-2hpg"} || adapter_info.architecture == std::string_view{"xe-lpg"} || adapter_info.architecture == std::string_view{"gen-12hp"}) { - config.enable_split_k_ = true; + enable_split_k_ = true; // Below thresholds are only verified on the above Intel GPUs without any regressions. The // proper value of `max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner_` may be // reduced when we support a larger `dim_inner` because larger `dim_inner` will bring more // atomic calls for each output value. - config.split_dim_inner_ = 256; - config.min_dim_inner_with_split_k_ = config.split_dim_inner_ * 2; - config.max_dim_inner_with_split_k_ = config.split_dim_inner_ * 9; - config.max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner_ = 35.0f; + split_dim_inner_ = 256; + min_dim_inner_with_split_k_ = split_dim_inner_ * 2; + max_dim_inner_with_split_k_ = split_dim_inner_ * 9; + max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner_ = 35.0f; } } - return config; } bool SplitKConfig::UseSplitK( diff --git a/onnxruntime/core/providers/webgpu/webgpu_utils.h b/onnxruntime/core/providers/webgpu/webgpu_utils.h index d45b9bf4dd119..7d5ab5fea8006 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_utils.h +++ b/onnxruntime/core/providers/webgpu/webgpu_utils.h @@ -91,9 +91,12 @@ inline Tensor CreateTensorView(const Tensor& tensor, MLDataType new_data_type, c return {new_data_type, new_shape, const_cast(tensor.DataRaw()), tensor.Location()}; } +/** + * Configuration for Split-K optimization (Conv|MatMul). + */ class SplitKConfig { public: - static SplitKConfig GetSplitKConfig(const wgpu::AdapterInfo& adapter_info); + explicit SplitKConfig(const wgpu::AdapterInfo& adapter_info); bool UseSplitK( bool is_vec4, ActivationKind activation_kind, uint64_t batch_size, From 454c61b7e36c8b080e2cf00c86e18dcf8f2d98ab Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 26 Nov 2025 04:58:37 +0000 Subject: [PATCH 04/20] Bump actions/checkout from 5 to 6 (#26641) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bumps [actions/checkout](https://github.com/actions/checkout) from 5 to 6.
Release notes

Sourced from actions/checkout's releases.

v6.0.0

What's Changed

Full Changelog: https://github.com/actions/checkout/compare/v5.0.0...v6.0.0

v6-beta

What's Changed

Updated persist-credentials to store the credentials under $RUNNER_TEMP instead of directly in the local git config.

This requires a minimum Actions Runner version of v2.329.0 to access the persisted credentials for Docker container action scenarios.

v5.0.1

What's Changed

Full Changelog: https://github.com/actions/checkout/compare/v5...v5.0.1

Changelog

Sourced from actions/checkout's changelog.

Changelog

V6.0.0

V5.0.1

V5.0.0

V4.3.1

V4.3.0

v4.2.2

v4.2.1

v4.2.0

v4.1.7

v4.1.6

v4.1.5

... (truncated)

Commits

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=actions/checkout&package-manager=github_actions&previous-version=5&new-version=6)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot merge` will merge this PR after your CI passes on it - `@dependabot squash and merge` will squash and merge this PR after your CI passes on it - `@dependabot cancel merge` will cancel a previously requested merge and block automerging - `@dependabot reopen` will reopen this PR if it is closed - `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually - `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency - `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself)
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/android.yml | 6 +++--- .github/workflows/cffconvert.yml | 2 +- .github/workflows/codeql.yml | 2 +- .../workflows/gradle-wrapper-validation.yml | 2 +- .github/workflows/ios.yml | 2 +- .github/workflows/lint.yml | 8 ++++---- .../linux-wasm-ci-build-and-test-workflow.yml | 2 +- .github/workflows/linux_cuda_ci.yml | 2 +- .github/workflows/linux_minimal_build.yml | 20 +++++++++---------- .github/workflows/linux_tensorrt_ci.yml | 2 +- .github/workflows/mac.yml | 4 ++-- .../macos-ci-build-and-test-workflow.yml | 2 +- .github/workflows/pr_checks.yml | 2 +- .github/workflows/publish-c-apidocs.yml | 2 +- .github/workflows/publish-csharp-apidocs.yml | 2 +- .github/workflows/publish-java-apidocs.yml | 2 +- .github/workflows/publish-js-apidocs.yml | 2 +- .../workflows/publish-objectivec-apidocs.yml | 2 +- .github/workflows/publish-python-apidocs.yml | 2 +- .github/workflows/react_native.yml | 8 ++++---- .github/workflows/reusable_linux_build.yml | 2 +- .github/workflows/web.yml | 2 +- .github/workflows/windows-web-ci-workflow.yml | 2 +- .github/workflows/windows_build_x64_asan.yml | 2 +- .github/workflows/windows_cuda.yml | 4 ++-- .github/workflows/windows_dml.yml | 2 +- .github/workflows/windows_openvino.yml | 2 +- .github/workflows/windows_qnn_x64.yml | 2 +- .github/workflows/windows_tensorrt.yml | 4 ++-- .github/workflows/windows_webgpu.yml | 6 +++--- .../windows_x64_debug_build_x64_debug.yml | 2 +- .../windows_x64_release_build_x64_release.yml | 2 +- ...build_x64_release_ep_generic_interface.yml | 2 +- ..._x64_release_vitisai_build_x64_release.yml | 2 +- .../workflows/windows_x64_release_xnnpack.yml | 2 +- .github/workflows/windows_x86.yml | 2 +- 36 files changed, 58 insertions(+), 58 deletions(-) diff --git a/.github/workflows/android.yml b/.github/workflows/android.yml index 7f7ff74959d52..f12eadc2ce794 100644 --- a/.github/workflows/android.yml +++ b/.github/workflows/android.yml @@ -27,7 +27,7 @@ jobs: runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: false @@ -112,7 +112,7 @@ jobs: android_nnapi_ep: runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - name: Use jdk 17 uses: actions/setup-java@v5 @@ -187,7 +187,7 @@ jobs: name: Android CI Pipeline runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - name: Use jdk 17 uses: actions/setup-java@v5 diff --git a/.github/workflows/cffconvert.yml b/.github/workflows/cffconvert.yml index 30f832f67c5ee..ddf4a52a0ccb0 100644 --- a/.github/workflows/cffconvert.yml +++ b/.github/workflows/cffconvert.yml @@ -12,7 +12,7 @@ jobs: runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] steps: - name: Check out a copy of the repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 - name: Check whether the citation metadata from CITATION.cff is valid uses: citation-file-format/cffconvert-github-action@2.0.0 diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml index d33e4d923a0bc..1db84400c272a 100644 --- a/.github/workflows/codeql.yml +++ b/.github/workflows/codeql.yml @@ -38,7 +38,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 # Initializes the CodeQL tools for scanning. - name: Initialize CodeQL diff --git a/.github/workflows/gradle-wrapper-validation.yml b/.github/workflows/gradle-wrapper-validation.yml index 04177b11e9c30..d8f13d13d3f88 100644 --- a/.github/workflows/gradle-wrapper-validation.yml +++ b/.github/workflows/gradle-wrapper-validation.yml @@ -15,7 +15,7 @@ jobs: name: "Validation" runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - uses: gradle/actions/wrapper-validation@v5 concurrency: group: ${{ github.workflow }}-${{ github.event_name == 'pull_request' && github.ref || github.sha }} diff --git a/.github/workflows/ios.yml b/.github/workflows/ios.yml index 0d2046b980783..ed572aa339ce9 100644 --- a/.github/workflows/ios.yml +++ b/.github/workflows/ios.yml @@ -20,7 +20,7 @@ jobs: runs-on: macos-14 steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: false - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.9 diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 5aaab5f8e1a10..5c618dc5787a5 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -17,7 +17,7 @@ jobs: name: Optional Lint runs-on: ubuntu-latest steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - name: misspell # Check spellings as well uses: reviewdog/action-misspell@v1 with: @@ -42,7 +42,7 @@ jobs: contents: read security-events: write steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - name: Setup Python uses: actions/setup-python@v6 with: @@ -87,7 +87,7 @@ jobs: name: Optional Lint C++ runs-on: ubuntu-latest steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - name: Update PATH run: | echo "$HOME/.local/bin" >> "$GITHUB_PATH" @@ -116,7 +116,7 @@ jobs: name: Lint JavaScript runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - uses: actions/setup-node@v6 with: node-version: 20 diff --git a/.github/workflows/linux-wasm-ci-build-and-test-workflow.yml b/.github/workflows/linux-wasm-ci-build-and-test-workflow.yml index 2370c631b7a7a..5763b9c39bcc6 100644 --- a/.github/workflows/linux-wasm-ci-build-and-test-workflow.yml +++ b/.github/workflows/linux-wasm-ci-build-and-test-workflow.yml @@ -49,7 +49,7 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: recursive diff --git a/.github/workflows/linux_cuda_ci.yml b/.github/workflows/linux_cuda_ci.yml index 886705471b7de..e7e3be8c5f9ed 100644 --- a/.github/workflows/linux_cuda_ci.yml +++ b/.github/workflows/linux_cuda_ci.yml @@ -48,7 +48,7 @@ jobs: packages: read steps: - name: Checkout code - uses: actions/checkout@v5 + uses: actions/checkout@v6 - uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.9 id: build_docker_image_step diff --git a/.github/workflows/linux_minimal_build.yml b/.github/workflows/linux_minimal_build.yml index af86975ee6cdc..4d9579a746892 100644 --- a/.github/workflows/linux_minimal_build.yml +++ b/.github/workflows/linux_minimal_build.yml @@ -28,7 +28,7 @@ jobs: packages: write steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: false @@ -65,7 +65,7 @@ jobs: id-token: write # If using OIDC for ACR login steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: false - uses: actions/setup-node@v6 @@ -122,7 +122,7 @@ jobs: id-token: write # If using OIDC for ACR login steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: false - uses: actions/setup-node@v6 @@ -156,7 +156,7 @@ jobs: id-token: write # If using OIDC for ACR login steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: false - uses: actions/setup-node@v6 @@ -188,7 +188,7 @@ jobs: id-token: write # If using OIDC for ACR login steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: false - uses: actions/setup-node@v6 @@ -222,7 +222,7 @@ jobs: id-token: write # If using OIDC for ACR login steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: false - uses: actions/setup-node@v6 @@ -286,7 +286,7 @@ jobs: id-token: write # If using OIDC for ACR login steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: false @@ -363,7 +363,7 @@ jobs: id-token: write # If using OIDC for ACR login steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: false @@ -430,7 +430,7 @@ jobs: id-token: write # If using OIDC for ACR login steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: false @@ -505,7 +505,7 @@ jobs: id-token: write # If using OIDC for ACR login steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: false - uses: actions/setup-node@v6 diff --git a/.github/workflows/linux_tensorrt_ci.yml b/.github/workflows/linux_tensorrt_ci.yml index 0e26576829e94..47b7c1ba7e889 100644 --- a/.github/workflows/linux_tensorrt_ci.yml +++ b/.github/workflows/linux_tensorrt_ci.yml @@ -48,7 +48,7 @@ jobs: packages: read steps: - name: Checkout code - uses: actions/checkout@v5 + uses: actions/checkout@v6 # --- Build the Docker image needed for testing --- - name: Build Docker Image for Testing diff --git a/.github/workflows/mac.yml b/.github/workflows/mac.yml index e545406d8d20f..8ba87bc1f731c 100644 --- a/.github/workflows/mac.yml +++ b/.github/workflows/mac.yml @@ -76,7 +76,7 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v5 + uses: actions/checkout@v6 - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.9 with: vcpkg-version: '2025.06.13' @@ -124,7 +124,7 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v5 + uses: actions/checkout@v6 - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.9 with: vcpkg-version: '2025.06.13' diff --git a/.github/workflows/macos-ci-build-and-test-workflow.yml b/.github/workflows/macos-ci-build-and-test-workflow.yml index 329584c68d7d1..8e1d0264496f6 100644 --- a/.github/workflows/macos-ci-build-and-test-workflow.yml +++ b/.github/workflows/macos-ci-build-and-test-workflow.yml @@ -75,7 +75,7 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v5 + uses: actions/checkout@v6 - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.9 with: vcpkg-version: '2025.06.13' diff --git a/.github/workflows/pr_checks.yml b/.github/workflows/pr_checks.yml index abe627f4ff7bc..7ca330742f69b 100644 --- a/.github/workflows/pr_checks.yml +++ b/.github/workflows/pr_checks.yml @@ -24,7 +24,7 @@ jobs: contents: read pull-requests: write steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - name: Setup Python uses: actions/setup-python@v6 with: diff --git a/.github/workflows/publish-c-apidocs.yml b/.github/workflows/publish-c-apidocs.yml index 25b7899584bbf..d9fb72271967f 100644 --- a/.github/workflows/publish-c-apidocs.yml +++ b/.github/workflows/publish-c-apidocs.yml @@ -24,7 +24,7 @@ jobs: name: Generate C/C++ API docs runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - name: Install doxygen and dependencies run: | sudo apt update diff --git a/.github/workflows/publish-csharp-apidocs.yml b/.github/workflows/publish-csharp-apidocs.yml index 34b9c1af9552f..dd55bbd917337 100644 --- a/.github/workflows/publish-csharp-apidocs.yml +++ b/.github/workflows/publish-csharp-apidocs.yml @@ -24,7 +24,7 @@ jobs: env: DOCFXVERSION: 2.62.2 steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - name: Install DocFX run: | dotnet tool update -g docfx diff --git a/.github/workflows/publish-java-apidocs.yml b/.github/workflows/publish-java-apidocs.yml index 656d0627ed17d..81defeae518a3 100644 --- a/.github/workflows/publish-java-apidocs.yml +++ b/.github/workflows/publish-java-apidocs.yml @@ -23,7 +23,7 @@ jobs: name: Generate Java docs runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - name: Set up JDK 11 uses: actions/setup-java@v5 with: diff --git a/.github/workflows/publish-js-apidocs.yml b/.github/workflows/publish-js-apidocs.yml index e71d3b3c57a4b..9da78d7d9ed9c 100644 --- a/.github/workflows/publish-js-apidocs.yml +++ b/.github/workflows/publish-js-apidocs.yml @@ -23,7 +23,7 @@ jobs: name: Generate JS API docs runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - name: Setup Node.js uses: actions/setup-node@v6 with: diff --git a/.github/workflows/publish-objectivec-apidocs.yml b/.github/workflows/publish-objectivec-apidocs.yml index 983d3d478a49d..a73b62eba6050 100644 --- a/.github/workflows/publish-objectivec-apidocs.yml +++ b/.github/workflows/publish-objectivec-apidocs.yml @@ -23,7 +23,7 @@ jobs: name: Generate Objective-C API docs runs-on: macos-latest steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.9 with: vcpkg-version: '2025.06.13' diff --git a/.github/workflows/publish-python-apidocs.yml b/.github/workflows/publish-python-apidocs.yml index 389d1683fb1ff..e35e6a04adbef 100644 --- a/.github/workflows/publish-python-apidocs.yml +++ b/.github/workflows/publish-python-apidocs.yml @@ -24,7 +24,7 @@ jobs: name: Generate Python API docs runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - name: Install tools run: | sudo apt-get update diff --git a/.github/workflows/react_native.yml b/.github/workflows/react_native.yml index 343186b1aec8c..4a56dfbd35406 100644 --- a/.github/workflows/react_native.yml +++ b/.github/workflows/react_native.yml @@ -20,7 +20,7 @@ jobs: aar_path: ${{ runner.temp }}/.artifacts steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: false @@ -75,7 +75,7 @@ jobs: run: echo "ANDROID_AVD_HOME=${{ runner.temp }}/android-avd" >> $GITHUB_ENV - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 - name: Use Python 3.12 uses: actions/setup-python@v6 @@ -175,7 +175,7 @@ jobs: timeout-minutes: 120 steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 - name: Use Xcode 15.3.0 run: sudo xcode-select --switch /Applications/Xcode_15.3.0.app/Contents/Developer @@ -218,7 +218,7 @@ jobs: timeout-minutes: 90 steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 - name: Download iOS pod artifact uses: actions/download-artifact@v6 diff --git a/.github/workflows/reusable_linux_build.yml b/.github/workflows/reusable_linux_build.yml index 795e35b06bfb0..f0da87647b8b0 100644 --- a/.github/workflows/reusable_linux_build.yml +++ b/.github/workflows/reusable_linux_build.yml @@ -75,7 +75,7 @@ jobs: id-token: write steps: - name: Checkout code - uses: actions/checkout@v5 + uses: actions/checkout@v6 - name: Set up Python ${{ inputs.python_version }} if: inputs.architecture != 'arm64' diff --git a/.github/workflows/web.yml b/.github/workflows/web.yml index 016feab5e0d94..6ae25ccc0bf3e 100644 --- a/.github/workflows/web.yml +++ b/.github/workflows/web.yml @@ -22,7 +22,7 @@ jobs: commit_sha: ${{ steps.extract_commit.outputs.commit_sha }} steps: - name: Checkout code - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: true diff --git a/.github/workflows/windows-web-ci-workflow.yml b/.github/workflows/windows-web-ci-workflow.yml index eee98332056f6..c16ce6eb222eb 100644 --- a/.github/workflows/windows-web-ci-workflow.yml +++ b/.github/workflows/windows-web-ci-workflow.yml @@ -29,7 +29,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: false diff --git a/.github/workflows/windows_build_x64_asan.yml b/.github/workflows/windows_build_x64_asan.yml index 05fd4acd4de9a..ac5f08717155f 100644 --- a/.github/workflows/windows_build_x64_asan.yml +++ b/.github/workflows/windows_build_x64_asan.yml @@ -19,7 +19,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: false diff --git a/.github/workflows/windows_cuda.yml b/.github/workflows/windows_cuda.yml index fd5b65eb039a3..5d6e9b1da31a2 100644 --- a/.github/workflows/windows_cuda.yml +++ b/.github/workflows/windows_cuda.yml @@ -21,7 +21,7 @@ jobs: name: Windows GPU CUDA CI Pipeline runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-vs2022-latest"] steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 with: fetch-depth: 0 submodules: 'none' @@ -152,7 +152,7 @@ jobs: timeout-minutes: 300 runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Win2022-GPU-A10"] steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 with: fetch-depth: 0 submodules: 'none' diff --git a/.github/workflows/windows_dml.yml b/.github/workflows/windows_dml.yml index e8ee7751348b4..0abf6b650f986 100644 --- a/.github/workflows/windows_dml.yml +++ b/.github/workflows/windows_dml.yml @@ -27,7 +27,7 @@ jobs: ONNXRUNTIME_TEST_GPU_DEVICE_ID: '0' runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Win2022-GPU-A10"] steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 with: fetch-depth: 0 # Fetch all history for all tags and branches submodules: 'none' diff --git a/.github/workflows/windows_openvino.yml b/.github/workflows/windows_openvino.yml index b608c0879aa45..537ff1fb00071 100644 --- a/.github/workflows/windows_openvino.yml +++ b/.github/workflows/windows_openvino.yml @@ -31,7 +31,7 @@ jobs: steps: - name: Checkout - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: fetch-depth: 0 submodules: none diff --git a/.github/workflows/windows_qnn_x64.yml b/.github/workflows/windows_qnn_x64.yml index 4f0b50e65df6e..f6176164354bb 100644 --- a/.github/workflows/windows_qnn_x64.yml +++ b/.github/workflows/windows_qnn_x64.yml @@ -31,7 +31,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 - name: Setup Python uses: actions/setup-python@v6 diff --git a/.github/workflows/windows_tensorrt.yml b/.github/workflows/windows_tensorrt.yml index 229efb01f0018..4a564a3b1cb36 100644 --- a/.github/workflows/windows_tensorrt.yml +++ b/.github/workflows/windows_tensorrt.yml @@ -21,7 +21,7 @@ jobs: name: Windows GPU TensorRT CI Pipeline runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-vs2022-latest"] steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 with: fetch-depth: 0 submodules: 'none' @@ -157,7 +157,7 @@ jobs: timeout-minutes: 300 runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Win2022-GPU-A10"] steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 with: fetch-depth: 0 submodules: 'none' diff --git a/.github/workflows/windows_webgpu.yml b/.github/workflows/windows_webgpu.yml index 899a8b66eac7a..f729cda5ea576 100644 --- a/.github/workflows/windows_webgpu.yml +++ b/.github/workflows/windows_webgpu.yml @@ -34,7 +34,7 @@ jobs: ONNXRUNTIME_TEST_GPU_DEVICE_ID: "0" steps: - name: Checkout - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: fetch-depth: 0 submodules: none @@ -156,7 +156,7 @@ jobs: timeout-minutes: 300 steps: - name: Checkout - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: fetch-depth: 0 submodules: none @@ -209,7 +209,7 @@ jobs: ONNXRUNTIME_TEST_GPU_DEVICE_ID: "0" steps: - name: Checkout - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: fetch-depth: 0 submodules: none diff --git a/.github/workflows/windows_x64_debug_build_x64_debug.yml b/.github/workflows/windows_x64_debug_build_x64_debug.yml index d62c7130e0ebb..385d03c1a6705 100644 --- a/.github/workflows/windows_x64_debug_build_x64_debug.yml +++ b/.github/workflows/windows_x64_debug_build_x64_debug.yml @@ -18,7 +18,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: false diff --git a/.github/workflows/windows_x64_release_build_x64_release.yml b/.github/workflows/windows_x64_release_build_x64_release.yml index a2991bb0f1131..ee045b70b6efa 100644 --- a/.github/workflows/windows_x64_release_build_x64_release.yml +++ b/.github/workflows/windows_x64_release_build_x64_release.yml @@ -18,7 +18,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: false diff --git a/.github/workflows/windows_x64_release_ep_generic_interface_build_x64_release_ep_generic_interface.yml b/.github/workflows/windows_x64_release_ep_generic_interface_build_x64_release_ep_generic_interface.yml index bb6c5035b0dce..25dfc41e6922c 100644 --- a/.github/workflows/windows_x64_release_ep_generic_interface_build_x64_release_ep_generic_interface.yml +++ b/.github/workflows/windows_x64_release_ep_generic_interface_build_x64_release_ep_generic_interface.yml @@ -18,7 +18,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: false diff --git a/.github/workflows/windows_x64_release_vitisai_build_x64_release.yml b/.github/workflows/windows_x64_release_vitisai_build_x64_release.yml index 4378231338673..e738db262f3a2 100644 --- a/.github/workflows/windows_x64_release_vitisai_build_x64_release.yml +++ b/.github/workflows/windows_x64_release_vitisai_build_x64_release.yml @@ -18,7 +18,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: false diff --git a/.github/workflows/windows_x64_release_xnnpack.yml b/.github/workflows/windows_x64_release_xnnpack.yml index b453cd570ac05..5672e4043c624 100644 --- a/.github/workflows/windows_x64_release_xnnpack.yml +++ b/.github/workflows/windows_x64_release_xnnpack.yml @@ -18,7 +18,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: false diff --git a/.github/workflows/windows_x86.yml b/.github/workflows/windows_x86.yml index d20778d56f60b..381d9dda5cd42 100644 --- a/.github/workflows/windows_x86.yml +++ b/.github/workflows/windows_x86.yml @@ -18,7 +18,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: false From b33c91ab77d6d4088d061e41a0addb20879abe6b Mon Sep 17 00:00:00 2001 From: xieofxie Date: Wed, 26 Nov 2025 13:37:29 +0800 Subject: [PATCH 05/20] add LogEvaluationStart for ReplayGraph (#26645) ### Description add LogEvaluationStart for ReplayGraph to match LogEvaluationStop ### Motivation and Context So by using ETW, could capture run time correctly Co-authored-by: hualxie --- onnxruntime/core/session/inference_session.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 4d4dea9cb444c..ab3932e7abfb4 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -2943,6 +2943,8 @@ Status InferenceSession::Run(const RunOptions& run_options, << cached_execution_provider_for_graph_replay_.Type() << " CUDA Graph for this model with tag: " << run_options.run_tag << " with graph annotation id: " << graph_annotation_id; + // log evaluation start to trace logging provider + env.GetTelemetryProvider().LogEvaluationStart(session_id_); ORT_RETURN_IF_ERROR_SESSIONID_(cached_execution_provider_for_graph_replay_.ReplayGraph(graph_annotation_id)); } else { InlinedVector exec_providers_to_stop; From 7845ea8442993def9138f7b53cd5228e0e42b11e Mon Sep 17 00:00:00 2001 From: xieofxie Date: Wed, 26 Nov 2025 13:46:48 +0800 Subject: [PATCH 06/20] add LogCompileModel to mark the session usage (#26646) ### Description add LogCompileModel to mark the session usage as Compile because that session will not be used for inference We could also use it to log compile model parameters if needed ### Motivation and Context We are building a profiling tool for WinML and we want to differentiate Compile session and inference session. I think there are two ways to do it but I don't know which is better https://github.com/microsoft/onnxruntime/pull/26646 https://github.com/microsoft/onnxruntime/pull/26647 --------- Co-authored-by: hualxie --- onnxruntime/core/platform/telemetry.cc | 4 ++++ onnxruntime/core/platform/telemetry.h | 2 ++ onnxruntime/core/platform/windows/telemetry.cc | 14 ++++++++++++++ onnxruntime/core/platform/windows/telemetry.h | 2 ++ onnxruntime/core/session/utils.cc | 1 + 5 files changed, 23 insertions(+) diff --git a/onnxruntime/core/platform/telemetry.cc b/onnxruntime/core/platform/telemetry.cc index 6cbbdd4e0a7ef..1eb03af3befa4 100644 --- a/onnxruntime/core/platform/telemetry.cc +++ b/onnxruntime/core/platform/telemetry.cc @@ -81,6 +81,10 @@ void Telemetry::LogSessionCreation(uint32_t session_id, int64_t ir_version, cons ORT_UNUSED_PARAMETER(captureState); } +void Telemetry::LogCompileModel(uint32_t session_id) const { + ORT_UNUSED_PARAMETER(session_id); +} + void Telemetry::LogRuntimeError(uint32_t session_id, const common::Status& status, const char* file, const char* function, uint32_t line) const { ORT_UNUSED_PARAMETER(session_id); diff --git a/onnxruntime/core/platform/telemetry.h b/onnxruntime/core/platform/telemetry.h index b60345e1b8a80..9c2859f7634b6 100644 --- a/onnxruntime/core/platform/telemetry.h +++ b/onnxruntime/core/platform/telemetry.h @@ -66,6 +66,8 @@ class Telemetry { const std::string& loadedFrom, const std::vector& execution_provider_ids, bool use_fp16, bool captureState) const; + virtual void LogCompileModel(uint32_t session_id) const; + virtual void LogRuntimeError(uint32_t session_id, const common::Status& status, const char* file, const char* function, uint32_t line) const; diff --git a/onnxruntime/core/platform/windows/telemetry.cc b/onnxruntime/core/platform/windows/telemetry.cc index 2e5d334856278..693e265af46b1 100644 --- a/onnxruntime/core/platform/windows/telemetry.cc +++ b/onnxruntime/core/platform/windows/telemetry.cc @@ -334,6 +334,20 @@ void WindowsTelemetry::LogSessionCreation(uint32_t session_id, int64_t ir_versio } } +void WindowsTelemetry::LogCompileModel(uint32_t session_id) const { + if (global_register_count_ == 0 || enabled_ == false) + return; + + TraceLoggingWrite(telemetry_provider_handle, + "CompileModel", + TraceLoggingBool(true, "UTCReplace_AppSessionGuid"), + TelemetryPrivacyDataTag(PDT_ProductAndServiceUsage), + TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES), + // Telemetry info + TraceLoggingUInt8(0, "schemaVersion"), + TraceLoggingUInt32(session_id, "sessionId")); +} + void WindowsTelemetry::LogRuntimeError(uint32_t session_id, const common::Status& status, const char* file, const char* function, uint32_t line) const { if (global_register_count_ == 0 || enabled_ == false) diff --git a/onnxruntime/core/platform/windows/telemetry.h b/onnxruntime/core/platform/windows/telemetry.h index 261d14a7fed8c..044feec071223 100644 --- a/onnxruntime/core/platform/windows/telemetry.h +++ b/onnxruntime/core/platform/windows/telemetry.h @@ -59,6 +59,8 @@ class WindowsTelemetry : public Telemetry { const std::string& loadedFrom, const std::vector& execution_provider_ids, bool use_fp16, bool captureState) const override; + void LogCompileModel(uint32_t session_id) const override; + void LogRuntimeError(uint32_t session_id, const common::Status& status, const char* file, const char* function, uint32_t line) const override; diff --git a/onnxruntime/core/session/utils.cc b/onnxruntime/core/session/utils.cc index 6189e6ca7f012..4cb21b80109c8 100644 --- a/onnxruntime/core/session/utils.cc +++ b/onnxruntime/core/session/utils.cc @@ -404,6 +404,7 @@ Status CompileModel(const Environment& env, const ModelCompilationOptions& model session))); } + Env::Default().GetTelemetryProvider().LogCompileModel(session->GetCurrentSessionId()); ORT_RETURN_IF_ERROR(ToStatusAndRelease(InitializeSession(session_options, *session))); return Status::OK(); } From 55bfa30465cbae1871d6a0355db1653a830f5f8b Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Thu, 27 Nov 2025 00:04:46 +0800 Subject: [PATCH 07/20] [webgpu] Fix bug introduced by RoE (#26661) Fix bug introduced by #26563 which used the wrong condition by accident and results incorrect result in graph capture mode. --- onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc index 606dbfde15c2c..3f5b28f92f55d 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc @@ -421,7 +421,7 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co indirect_buffer_ptr, tile_size)); Q = &query_output; } else { - ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, use_indirect_dispatch ? seqlen_k : nullptr, indirect_buffer_ptr)); + ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, use_seqlen_k ? seqlen_k : nullptr, indirect_buffer_ptr)); } if (parameters.sequence_length_ > 1) { From f02a6407687ec8c8982a15249809b93918cf20ff Mon Sep 17 00:00:00 2001 From: qti-hungjuiw Date: Thu, 27 Nov 2025 01:23:04 +0800 Subject: [PATCH 08/20] [QNN-EP] Enable verbose and artifacts saving in onnxruntime_provider_test.exe (#26396) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Description - The change allows users to better debug unit tests by adding the following environment variables: - `QNN_DUMP_ONNX`: Dump input onnx model - `QNN_DUMP_JSON`: Dump json qnn graph with provider_option `dump_json_qnn_graph` - `QNN_DUMP_DLC`: Dump dlc with provider_option `qnn_ir_backend_path` - `QNN_VERBOSE`: Use the log level `ORT_LOGGING_LEVEL_VERBOSE` - Developers can use the environment variables above to save the artifacts of QNN-EP testcases to a directory named with `_` ``` . ├── QnnCPUBackendTests_BatchNorm2D_fp32 # RunQnnModelTest │ ├── dumped_f32_model.onnx # float32 ONNX model │ ├── QNNExecutionProvider_QNN_XXXX_X_X.dlc │ └── QNNExecutionProvider_QNN_XXXX_X_X.json ├── QnnHTPBackendTests_BatchNorm_FP16 # TestFp16ModelAccuracy │ ├── dumped_f16_model.onnx # float16 ONNX model │ ├── dumped_f32_model.onnx # float32 ONNX model │ ├── QNNExecutionProvider_QNN_XXXX_X_X.dlc │ └── QNNExecutionProvider_QNN_XXXX_X_X.json └── QnnHTPBackendTests_BatchNorm2D_U8U8S32 # TestQDQModelAccuracy ├── dumped_f32_model.onnx # float32 ONNX model ├── dumped_qdq_model.onnx # QDQ ONNX model ├── QNNExecutionProvider_QNN_XXXX_X_X.dlc └── QNNExecutionProvider_QNN_XXXX_X_X.json # All artifact files are placed under the current working directory from which the test binary is invoked. ``` ### Motivation and Context - The Json qnn graph/dlc are helpful for backend to debug performance/accuracy issues - By comparing the onnx and Json qnn graph/dlc, we can locate the issue about graph manipulation. --- .../qnn/builder/opbuilder/base_op_builder.cc | 3 + .../core/providers/qnn/builder/qnn_def.cc | 4 + .../core/providers/qnn/builder/qnn_def.h | 2 + onnxruntime/test/providers/qnn/README.md | 70 +++++++++ .../test/providers/qnn/qnn_test_utils.cc | 60 +++++++ .../test/providers/qnn/qnn_test_utils.h | 147 +++++++++++++++++- 6 files changed, 278 insertions(+), 8 deletions(-) create mode 100644 onnxruntime/test/providers/qnn/README.md diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc index 4d183b95bd938..0bb3accb4d754 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc @@ -76,6 +76,9 @@ Status BaseOpBuilder::ProcessDataTypes(QnnModelWrapper& qnn_model_wrapper, return CheckHtpDataTypes(input_qnn_dtypes, output_qnn_dtypes); } else if (IsGpuBackend(qnn_model_wrapper.GetQnnBackendType())) { return CheckGpuDataTypes(input_qnn_dtypes, output_qnn_dtypes); + } else if (IsIrBackend(qnn_model_wrapper.GetQnnBackendType())) { + // TODO: CheckIrDataTypes + return Status::OK(); } return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Only support backend: CPU, HTP and GPU"); } diff --git a/onnxruntime/core/providers/qnn/builder/qnn_def.cc b/onnxruntime/core/providers/qnn/builder/qnn_def.cc index f3d81d7d2fdd7..9f28e2609faa1 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_def.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_def.cc @@ -574,6 +574,10 @@ bool QnnOpConfigWrapper::CreateQnnGraphOp(const QNN_INTERFACE_VER_TYPE& qnn_inte return true; } +bool IsIrBackend(QnnBackendType backend_type) { + return backend_type == QnnBackendType::SERIALIZER; +} + bool IsNpuBackend(QnnBackendType backend_type) { return backend_type == QnnBackendType::HTP || backend_type == QnnBackendType::DSP; } diff --git a/onnxruntime/core/providers/qnn/builder/qnn_def.h b/onnxruntime/core/providers/qnn/builder/qnn_def.h index 42f4d7bb60f34..77508f3934a20 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_def.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_def.h @@ -96,6 +96,8 @@ enum class QnnBackendType : uint8_t { SERIALIZER, }; +bool IsIrBackend(QnnBackendType backend_type); + bool IsCpuBackend(QnnBackendType backend_type); bool IsNpuBackend(QnnBackendType backend_type); diff --git a/onnxruntime/test/providers/qnn/README.md b/onnxruntime/test/providers/qnn/README.md new file mode 100644 index 0000000000000..c3d0c720a1aa4 --- /dev/null +++ b/onnxruntime/test/providers/qnn/README.md @@ -0,0 +1,70 @@ +# ONNX Runtime QNN Execution Provider Tests +## Overview +1. The `onnxruntime/test/providers/qnn` directory contains integration tests for the Qualcomm Neural Network (QNN) execution provider. +2. Most testcases run an ONNX model through the QNN-EP, then verifies the inference result against the one on CPU-EP + +## Building the Tests +The tests are built as part of the regular ONNX Runtime build. After a successful build you will have an executable named +- onnxruntime_provider_test.exe (Windows) +- onnxruntime_provider_test (Linux/macOS) + +## Running the Tests +1. QNN supports several backends. You can use the standard Google‑Test syntax for filtering: + - `onnxruntime_provider_test.exe --gtest_filter=QnnCPUBackendTests.*` + - `onnxruntime_provider_test.exe --gtest_filter=QnnHTPBackendTests.*` + - `onnxruntime_provider_test.exe --gtest_filter=QnnGPUBackendTests.*` + - `onnxruntime_provider_test.exe --gtest_filter=QnnIRBackendTests.*` +2. Saving Test Artifacts + - For debugging it is often helpful to keep the intermediate files that the tests generate. The following environment + variables are recognized by the test binary: + - `QNN_DUMP_ONNX`: Saves the input ONNX model used for the test + - `QNN_DUMP_JSON`: Save json qnn graph with provider_option `dump_json_qnn_graph` + - `QNN_DUMP_DLC`: Saves the compiled QNN DLC file by specifying the provider_option `backend_path` to `QnnIr.dll` + - The artifacts will be saved to a directory named with `_` + ``` + . + ├── QnnCPUBackendTests_BatchNorm2D_fp32 # RunQnnModelTest + │ ├── dumped_f32_model.onnx # float32 ONNX model + │ ├── QNNExecutionProvider_QNN_XXXX_X_X.dlc + │ └── QNNExecutionProvider_QNN_XXXX_X_X.json + ├── QnnHTPBackendTests_BatchNorm_FP16 # TestFp16ModelAccuracy + │ ├── dumped_f16_model.onnx # float16 ONNX model + │ ├── dumped_f32_model.onnx # float32 ONNX model + │ ├── QNNExecutionProvider_QNN_XXXX_X_X.dlc + │ └── QNNExecutionProvider_QNN_XXXX_X_X.json + └── QnnHTPBackendTests_BatchNorm2D_U8U8S32 # TestQDQModelAccuracy + ├── dumped_f32_model.onnx # float32 ONNX model + ├── dumped_qdq_model.onnx # QDQ ONNX model + ├── QNNExecutionProvider_QNN_XXXX_X_X.dlc + └── QNNExecutionProvider_QNN_XXXX_X_X.json + + # All artifact files are placed under the current working directory from which the test binary is invoked. + ``` +3. Verbose + - `QNN_VERBOSE`: Sets the ONNX Runtime log level to `ORT_LOGGING_LEVEL_VERBOSE` + +4. You can enable any combination of these environment variables, for example: + - On Linux/macOS + ```bash + export QNN_DUMP_ONNX=1 + export QNN_DUMP_JSON=1 + export QNN_DUMP_DLC=1 + export QNN_VERBOSE=1 + ``` + - On Windows + ```cmd + set QNN_DUMP_ONNX=1 + set QNN_DUMP_JSON=1 + set QNN_DUMP_DLC=1 + set QNN_VERBOSE=1 + ``` + ```ps1 + $Env:QNN_DUMP_ONNX = "1" + $Env:QNN_DUMP_JSON = "1" + $Env:QNN_DUMP_DLC = "1" + $Env:QNN_VERBOSE = "1" + ``` + +# Note +- An issue on QNN backends can prevent the test artifacts from being successfully saved. +- The `onnxruntime_provider_test.exe` does not automatically delete the artifact directories, so you may want to prune them after a debugging session. diff --git a/onnxruntime/test/providers/qnn/qnn_test_utils.cc b/onnxruntime/test/providers/qnn/qnn_test_utils.cc index 1c70f4012090e..15a9132aaa16c 100644 --- a/onnxruntime/test/providers/qnn/qnn_test_utils.cc +++ b/onnxruntime/test/providers/qnn/qnn_test_utils.cc @@ -101,6 +101,12 @@ void RunQnnModelTest(const GetTestModelFn& build_test_case, ProviderOptions prov int opset_version, ExpectedEPNodeAssignment expected_ep_assignment, float fp32_abs_err, logging::Severity log_severity, bool verify_outputs, std::function* ep_graph_checker) { + std::filesystem::path output_dir; + if (QNNTestEnvironment::GetInstance().dump_onnx() || + QNNTestEnvironment::GetInstance().dump_json() || + QNNTestEnvironment::GetInstance().dump_dlc()) { + output_dir = QNNTestEnvironment::GetInstance().CreateTestcaseDirs(); + } EPVerificationParams verification_params; verification_params.ep_node_assignment = expected_ep_assignment; verification_params.fp32_abs_err = fp32_abs_err; @@ -110,6 +116,10 @@ void RunQnnModelTest(const GetTestModelFn& build_test_case, ProviderOptions prov auto& logging_manager = DefaultLoggingManager(); logging_manager.SetDefaultLoggerSeverity(log_severity); + if (QNNTestEnvironment::GetInstance().verbose()) { + logging_manager.RemoveSink(logging::SinkType::EtwSink); + logging_manager.SetDefaultLoggerSeverity(logging::Severity::kVERBOSE); + } onnxruntime::Model model("QNN_EP_TestModel", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {}, @@ -123,7 +133,27 @@ void RunQnnModelTest(const GetTestModelFn& build_test_case, ProviderOptions prov // Serialize the model to a string. std::string model_data; model.ToProto().SerializeToString(&model_data); + + if (QNNTestEnvironment::GetInstance().dump_onnx()) { + auto dump_path = output_dir / ToPathString("dumped_f32_model.onnx"); + LOGS(logging_manager.DefaultLogger(), VERBOSE) << "Save onnx model at: " << dump_path; + ASSERT_STATUS_OK(onnxruntime::Model::Save(model, dump_path)); + } + TryEnableQNNSaver(provider_options); + if (QNNTestEnvironment::GetInstance().dump_dlc()) { + provider_options["dump_qnn_ir_dlc"] = "1"; + provider_options["dump_qnn_ir_dlc_dir"] = output_dir.string(); +#if defined(_WIN32) + provider_options["qnn_ir_backend_path"] = "QnnIr.dll"; +#else + provider_options["qnn_ir_backend_path"] = "libQnnIr.so"; +#endif // defined(_WIN32) + } + if (QNNTestEnvironment::GetInstance().dump_json()) { + provider_options["dump_json_qnn_graph"] = "1"; + provider_options["json_qnn_graph_dir"] = output_dir.string(); + } RunAndVerifyOutputsWithEP(AsByteSpan(model_data.data(), model_data.size()), "QNN_EP_TestLogID", QnnExecutionProviderWithOptions(provider_options), helper.feeds_, verification_params, @@ -134,11 +164,21 @@ void RunQnnModelTestHTPNoVerify(const GetTestModelFn& build_test_case, ProviderO int opset_version, ExpectedEPNodeAssignment expected_ep_assignment, logging::Severity log_severity, std::function* ep_graph_checker) { + std::filesystem::path output_dir; + if (QNNTestEnvironment::GetInstance().dump_onnx() || + QNNTestEnvironment::GetInstance().dump_dlc() || + QNNTestEnvironment::GetInstance().dump_json()) { + output_dir = QNNTestEnvironment::GetInstance().CreateTestcaseDirs(); + } // Add kMSDomain to cover contrib op like Gelu const std::unordered_map domain_to_version = {{"", opset_version}, {kMSDomain, 1}}; auto& logging_manager = DefaultLoggingManager(); logging_manager.SetDefaultLoggerSeverity(log_severity); + if (QNNTestEnvironment::GetInstance().verbose()) { + logging_manager.RemoveSink(logging::SinkType::EtwSink); + logging_manager.SetDefaultLoggerSeverity(logging::Severity::kVERBOSE); + } onnxruntime::Model model("QNN_EP_TestModel", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {}, @@ -152,7 +192,27 @@ void RunQnnModelTestHTPNoVerify(const GetTestModelFn& build_test_case, ProviderO // Serialize the model to a string. std::string model_data; model.ToProto().SerializeToString(&model_data); + + if (QNNTestEnvironment::GetInstance().dump_onnx()) { + auto dump_path = output_dir / ToPathString("dumped_f32_model.onnx"); + LOGS(logging_manager.DefaultLogger(), VERBOSE) << "Save onnx model at: " << dump_path; + ASSERT_STATUS_OK(onnxruntime::Model::Save(model, dump_path)); + } + TryEnableQNNSaver(provider_options); + if (QNNTestEnvironment::GetInstance().dump_dlc()) { + provider_options["dump_qnn_ir_dlc"] = "1"; + provider_options["dump_qnn_ir_dlc_dir"] = output_dir.string(); +#if defined(_WIN32) + provider_options["qnn_ir_backend_path"] = "QnnIr.dll"; +#else + provider_options["qnn_ir_backend_path"] = "libQnnIr.so"; +#endif // defined(_WIN32) + } + if (QNNTestEnvironment::GetInstance().dump_json()) { + provider_options["dump_json_qnn_graph"] = "1"; + provider_options["json_qnn_graph_dir"] = output_dir.string(); + } SessionOptions so; so.session_logid = "QNN_EP_TestLogID"; diff --git a/onnxruntime/test/providers/qnn/qnn_test_utils.h b/onnxruntime/test/providers/qnn/qnn_test_utils.h index aeb3a9a114871..4d4f795d161b1 100644 --- a/onnxruntime/test/providers/qnn/qnn_test_utils.h +++ b/onnxruntime/test/providers/qnn/qnn_test_utils.h @@ -499,6 +499,77 @@ struct QDQTolerance { float value; }; +class QNNTestEnvironment { + public: + // Delete copy constructor and assignment operator + QNNTestEnvironment(const QNNTestEnvironment&) = delete; + QNNTestEnvironment& operator=(const QNNTestEnvironment&) = delete; + + // Static method to get the singleton instance + static QNNTestEnvironment& GetInstance() { + static QNNTestEnvironment instance; + return instance; + } + + bool dump_onnx() const { return dump_onnx_; } + bool dump_json() const { return dump_json_; } + bool dump_dlc() const { return dump_dlc_; } + bool verbose() const { return verbose_; } + + std::filesystem::path CreateTestcaseDirs() { + std::string test_suite_name = ::testing::UnitTest::GetInstance()->current_test_info()->test_suite_name(); + std::string test_name = ::testing::UnitTest::GetInstance()->current_test_info()->name(); + std::filesystem::path output_dir = std::filesystem::current_path() / (test_suite_name + "_" + test_name); + std::filesystem::create_directories(output_dir); + + return output_dir; + } + + private: + // Private constructor for singleton + QNNTestEnvironment() { + ParseEnvironmentVars(); + } + + // Helper function to check if an environment variable is set + bool IsEnvVarSet(const char* name) { + const char* value = std::getenv(name); + if (value == nullptr) { + return false; + } + + // Consider the variable set if it's not empty and not "0" + return *value != '\0' && *value != '0'; + } + + void ParseEnvironmentVars() { + if (IsEnvVarSet("QNN_DUMP_ONNX")) { + std::cout << "[QNN only] ONNX model dumping enabled via environment variable." << std::endl; + dump_onnx_ = true; + } + + if (IsEnvVarSet("QNN_DUMP_JSON")) { + std::cout << "[QNN only] Json QNN Graph dumping enabled via environment variable." << std::endl; + dump_json_ = true; + } + + if (IsEnvVarSet("QNN_DUMP_DLC")) { + std::cout << "[QNN only] DLC dumping enabled via environment variable." << std::endl; + dump_dlc_ = true; + } + + if (IsEnvVarSet("QNN_VERBOSE")) { + std::cout << "Verbose enabled via environment variable." << std::endl; + verbose_ = true; + } + } + + bool dump_onnx_ = false; + bool dump_json_ = false; + bool dump_dlc_ = false; + bool verbose_ = false; +}; + /** * Tests the accuracy of a QDQ model on QNN EP by runnning 3 inferences: * @@ -529,15 +600,21 @@ inline void TestQDQModelAccuracy(const GetTestModelFn& f32_model_fn, const GetTe const std::string& qnn_ctx_model_path = "", const std::unordered_map& session_option_pairs = {}, std::function* qnn_ep_graph_checker = nullptr) { + std::filesystem::path output_dir; + if (QNNTestEnvironment::GetInstance().dump_onnx() || + QNNTestEnvironment::GetInstance().dump_dlc() || + QNNTestEnvironment::GetInstance().dump_json()) { + output_dir = QNNTestEnvironment::GetInstance().CreateTestcaseDirs(); + } // Add kMSDomain to cover contrib op like Gelu const std::unordered_map domain_to_version = {{"", opset_version}, {kMSDomain, 1}}; auto& logging_manager = DefaultLoggingManager(); - - // Uncomment to dump LOGGER() output to stdout. - // logging_manager.RemoveSink(logging::SinkType::EtwSink); - logging_manager.SetDefaultLoggerSeverity(log_severity); + if (QNNTestEnvironment::GetInstance().verbose()) { + logging_manager.RemoveSink(logging::SinkType::EtwSink); + logging_manager.SetDefaultLoggerSeverity(logging::Severity::kVERBOSE); + } // Create float model and serialize it to a string. onnxruntime::Model f32_model("f32_model", false, ModelMetaData(), PathString(), @@ -551,8 +628,11 @@ inline void TestQDQModelAccuracy(const GetTestModelFn& f32_model_fn, const GetTe ASSERT_STATUS_OK(f32_model.MainGraph().Resolve()); f32_model.ToProto().SerializeToString(&f32_model_data); - // Uncomment to save f32 model to disk for debugging. - // ASSERT_STATUS_OK(onnxruntime::Model::Save(f32_model, ToPathString("cmp_accuracy.f32.onnx"))); + if (QNNTestEnvironment::GetInstance().dump_onnx()) { + auto dump_path = output_dir / ToPathString("dumped_f32_model.onnx"); + LOGS(logging_manager.DefaultLogger(), VERBOSE) << "Save onnx float32 model at: " << dump_path; + ASSERT_STATUS_OK(onnxruntime::Model::Save(f32_model, dump_path)); + } // Run f32 model on CPU EP and collect outputs. std::vector cpu_f32_outputs; @@ -594,11 +674,27 @@ inline void TestQDQModelAccuracy(const GetTestModelFn& f32_model_fn, const GetTe ASSERT_STATUS_OK(qdq_model.MainGraph().Resolve()); qdq_model.ToProto().SerializeToString(&qdq_model_data); - // Uncomment to save QDQ model to disk for debugging. - // ASSERT_STATUS_OK(onnxruntime::Model::Save(qdq_model, ToPathString("cmp_accuracy.qdq.onnx"))); + if (QNNTestEnvironment::GetInstance().dump_onnx()) { + auto dump_path = output_dir / ToPathString("dumped_qdq_model.onnx"); + LOGS(logging_manager.DefaultLogger(), VERBOSE) << "Save onnx QDQ model at: " << dump_path; + ASSERT_STATUS_OK(onnxruntime::Model::Save(qdq_model, dump_path)); + } bool is_qnn_ep = true; TryEnableQNNSaver(qnn_options); + if (QNNTestEnvironment::GetInstance().dump_dlc()) { + qnn_options["dump_qnn_ir_dlc"] = "1"; + qnn_options["dump_qnn_ir_dlc_dir"] = output_dir.string(); +#if defined(_WIN32) + qnn_options["qnn_ir_backend_path"] = "QnnIr.dll"; +#else + qnn_options["qnn_ir_backend_path"] = "libQnnIr.so"; +#endif // defined(_WIN32) + } + if (QNNTestEnvironment::GetInstance().dump_json()) { + qnn_options["dump_json_qnn_graph"] = "1"; + qnn_options["json_qnn_graph_dir"] = output_dir.string(); + } std::vector qnn_qdq_outputs; if (!qnn_ctx_model_path.empty()) { onnx::ModelProto model_proto; @@ -743,11 +839,21 @@ inline void TestFp16ModelAccuracy(const GetTestModelFn& f32_model_fn, logging::Severity log_severity = logging::Severity::kERROR, const std::string& qnn_ctx_model_path = "", const std::unordered_map& session_option_pairs = {}) { + std::filesystem::path output_dir; + if (QNNTestEnvironment::GetInstance().dump_onnx() || + QNNTestEnvironment::GetInstance().dump_dlc() || + QNNTestEnvironment::GetInstance().dump_json()) { + output_dir = QNNTestEnvironment::GetInstance().CreateTestcaseDirs(); + } // Add kMSDomain to cover contrib op like Gelu const std::unordered_map domain_to_version = {{"", opset_version}, {kMSDomain, 1}}; auto& logging_manager = DefaultLoggingManager(); logging_manager.SetDefaultLoggerSeverity(log_severity); + if (QNNTestEnvironment::GetInstance().verbose()) { + logging_manager.RemoveSink(logging::SinkType::EtwSink); + logging_manager.SetDefaultLoggerSeverity(logging::Severity::kVERBOSE); + } // Create float model and serialize it to a string. onnxruntime::Model f32_model("f32_model", false, ModelMetaData(), PathString(), @@ -760,6 +866,12 @@ inline void TestFp16ModelAccuracy(const GetTestModelFn& f32_model_fn, ASSERT_STATUS_OK(f32_model.MainGraph().Resolve()); f32_model.ToProto().SerializeToString(&f32_model_data); + if (QNNTestEnvironment::GetInstance().dump_onnx()) { + auto dump_path = output_dir / ToPathString("dumped_f32_model.onnx"); + LOGS(logging_manager.DefaultLogger(), VERBOSE) << "Save onnx float32 model at: " << dump_path; + ASSERT_STATUS_OK(onnxruntime::Model::Save(f32_model, dump_path)); + } + // Run f32 model on CPU EP and collect outputs. std::vector cpu_f32_outputs; InferenceModel(f32_model_data, "f32_model_logger", {}, ExpectedEPNodeAssignment::All, @@ -796,8 +908,27 @@ inline void TestFp16ModelAccuracy(const GetTestModelFn& f32_model_fn, ASSERT_STATUS_OK(f16_model.MainGraph().Resolve()); f16_model.ToProto().SerializeToString(&f16_model_data); + if (QNNTestEnvironment::GetInstance().dump_onnx()) { + auto dump_path = output_dir / ToPathString("dumped_f16_model.onnx"); + LOGS(logging_manager.DefaultLogger(), VERBOSE) << "Save onnx float16 model at: " << dump_path; + ASSERT_STATUS_OK(onnxruntime::Model::Save(f16_model, dump_path)); + } + bool is_qnn_ep = true; TryEnableQNNSaver(qnn_options); + if (QNNTestEnvironment::GetInstance().dump_dlc()) { + qnn_options["dump_qnn_ir_dlc"] = "1"; + qnn_options["dump_qnn_ir_dlc_dir"] = output_dir.string(); +#if defined(_WIN32) + qnn_options["qnn_ir_backend_path"] = "QnnIr.dll"; +#else + qnn_options["qnn_ir_backend_path"] = "libQnnIr.so"; +#endif // defined(_WIN32) + } + if (QNNTestEnvironment::GetInstance().dump_json()) { + qnn_options["dump_json_qnn_graph"] = "1"; + qnn_options["json_qnn_graph_dir"] = output_dir.string(); + } std::vector qnn_f16_outputs; if (!qnn_ctx_model_path.empty()) { onnx::ModelProto model_proto; From 4c43c6697f2ab6c479506398aef2c265ec59d2ec Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Wed, 26 Nov 2025 21:29:47 -0500 Subject: [PATCH 09/20] [webgpu] Use multiplication instead of pow if exponent is 2 (#26667) ### Description More accurately compute Pow(2.0) on WebGPU EP. Reproduction script: ```py from onnx import helper, TensorProto import onnxruntime as ort import numpy as np # 1. Create the ONNX model # Define input and output input_info = helper.make_tensor_value_info('X', TensorProto.FLOAT, [1, 1]) output_info = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [1, 1]) # Create a constant tensor for the exponent (2.0) exponent_tensor = helper.make_tensor('exponent', TensorProto.FLOAT, [], [2.0]) exponent_node = helper.make_node('Constant', [], ['exponent_out'], value=exponent_tensor) # Create the Pow node # Pow takes two inputs: Base (X) and Power (exponent_out) pow_node = helper.make_node( 'Pow', inputs=['X', 'exponent_out'], outputs=['Y'], name='PowNode' ) # Create the graph graph_def = helper.make_graph( [exponent_node, pow_node], 'test-model', [input_info], [output_info] ) # Create the model model_def = helper.make_model(graph_def, producer_name='onnx-example') opset = model_def.opset_import[0] opset.version = 13 # Ensure opset version supports the operations # 2. Convert model to string (bytes) model_str = model_def.SerializeToString() # 3. Prepare input data np.random.seed(0) input_data = np.array([[-2e3]], dtype=np.float32) # 4. Run on CPUExecutionProvider sess_cpu = ort.InferenceSession(model_str, providers=['CPUExecutionProvider']) res_cpu = sess_cpu.run(['Y'], {'X': input_data})[0] print("CPU Result:", res_cpu) # 5. Run on WebGpuExecutionProvider sess_webgpu = ort.InferenceSession(model_str, providers=['WebGpuExecutionProvider']) res_webgpu = sess_webgpu.run(['Y'], {'X': input_data})[0] print("WebGPU Result:", res_webgpu) # Compare results diff = np.abs(res_cpu - res_webgpu) max_diff = diff.max().item() assert max_diff < 1e-5, f"Results do not match within tolerance! Max diff: {max_diff}" print("Results match!") ``` currently produces ``` CPU Result: [[4.e+06]] WebGPU Result: [[3.999999e+06]] --------------------------------------------------------------------------- AssertionError Traceback (most recent call last) Cell In[1], [line 56](vscode-notebook-cell:?execution_count=1&line=56) 54 diff = np.abs(res_cpu - res_webgpu) 55 max_diff = diff.max().item() ---> [56](vscode-notebook-cell:?execution_count=1&line=56) assert max_diff < 1e-5, f"Results do not match within tolerance! Max diff: {max_diff}" 57 print("Results match!") AssertionError: Results do not match within tolerance! Max diff: 1.0 ``` but with this PR: ``` CPU Result: [[4.e+06]] WebGPU Result: [[4.e+06]] Results match! ``` ### Motivation and Context Leads to downstream issues/inaccuracies for certain models, especially those which have larger values to compute pow(x,2) for. cc @guschmue --- .../providers/webgpu/math/binary_elementwise_ops.cc | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc b/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc index 82645e30082e6..3c974ef5133c0 100644 --- a/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc +++ b/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc @@ -322,11 +322,14 @@ std::string GetPowImpl(int lhs_element_type, int /* rhs_element_type */) { if (lhs_element_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32) { round_str = "round"; } - std::string use_sqrt_for_pow; + std::string use_pow_shortcut; if (lhs_element_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT || lhs_element_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16) { + // use multiplication instead of pow when base (a) is a float and exponent (b) is 2.0 // use sqrt instead of pow when base (a) is a positive float and exponent (b) is 0.5 - use_sqrt_for_pow = - " else if (a >= input_a_element_t(0.0) && b == 0.5) {\n" + use_pow_shortcut = + " else if (b == 2.0) {\n" + " return a * a;\n" + " } else if (a >= input_a_element_t(0.0) && b == 0.5) {\n" " return sqrt(a);\n" " }\n"; } @@ -337,7 +340,7 @@ std::string GetPowImpl(int lhs_element_type, int /* rhs_element_type */) { " } else if (a < input_a_element_t(0.0) && b != floor(b)) {\n" " return input_a_element_t(pow(f32(a), b)); // NaN\n" " }\n" - << use_sqrt_for_pow + << use_pow_shortcut << " return select(sign(a), input_a_element_t(1.0), round(abs(b) % 2.0) != 1.0) * input_a_element_t(" << round_str << "(pow(f32(abs(a)), b)));\n" << "}\n" "fn pow_v(a : vec4, b : vec4) -> vec4 {\n" From 71f4e67d8e42615784a322d55993eed066a20c94 Mon Sep 17 00:00:00 2001 From: Christian Bourjau Date: Mon, 1 Dec 2025 07:25:59 +0100 Subject: [PATCH 10/20] Avoid creation of temporary protobuf object (#26681) ### Description While profiling session creation time for large graphs (number of nodes, not size of tensors), we noticed that the creations and subsequent destructions of protobuf objects were the major hotspot. This PR avoids its creation. Signed-off-by: Christian Bourjau --- onnxruntime/core/framework/allocation_planner.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/onnxruntime/core/framework/allocation_planner.cc b/onnxruntime/core/framework/allocation_planner.cc index e77496b6e8196..1c80d83f99feb 100644 --- a/onnxruntime/core/framework/allocation_planner.cc +++ b/onnxruntime/core/framework/allocation_planner.cc @@ -499,8 +499,7 @@ class PlannerImpl { /*! \brief Given a tensor-type, return the size of an element of the tensor. */ static size_t GetElementSize(const DataType& tensor_type) { - const TypeProto& type_proto = ONNX_NAMESPACE::Utils::DataTypeUtils::ToTypeProto(tensor_type); - MLDataType ml_data_type = DataTypeImpl::TypeFromProto(type_proto); + MLDataType ml_data_type = DataTypeImpl::GetDataType(*tensor_type); const TensorTypeBase* tensor_type_base = ml_data_type->AsTensorType(); ORT_ENFORCE(nullptr != tensor_type_base); MLDataType elt_type = tensor_type_base->GetElementType(); From eab7c9ac720b9673b60bd31340e9cf4a3568effd Mon Sep 17 00:00:00 2001 From: Christian Bourjau Date: Mon, 1 Dec 2025 07:27:19 +0100 Subject: [PATCH 11/20] Use `std::string_view` directly as key to `absl::flat_hash_map::find` (#26682) ### Description Use `std::string_view` directly as key in `find` method of `flat_hash_map`. This part of the absl documentation may provide further insights: https://abseil.io/docs/cpp/guides/container#heterogeneous-lookup ### Motivation and Context We noticed this when profiling the session creation of large models (in terms of the number of nodes). Signed-off-by: Christian Bourjau --- onnxruntime/core/framework/ort_value_name_idx_map.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/framework/ort_value_name_idx_map.h b/onnxruntime/core/framework/ort_value_name_idx_map.h index 76e7e369514d4..6035dc4e85242 100644 --- a/onnxruntime/core/framework/ort_value_name_idx_map.h +++ b/onnxruntime/core/framework/ort_value_name_idx_map.h @@ -33,7 +33,7 @@ class OrtValueNameIdxMap { common::Status GetIdx(std::string_view name, int& idx) const { idx = -1; - auto it = map_.find(std::string(name)); + auto it = map_.find(name); if (it == map_.end()) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Could not find OrtValue with name '", name, "'"); } From 0593840775e3075cfcebf6f0450d86a1f6bb9670 Mon Sep 17 00:00:00 2001 From: Xiaofei Han Date: Tue, 2 Dec 2025 00:34:28 +0800 Subject: [PATCH 12/20] [webgpu] Convert i32 to u32 in uniforms (#26676) In debug mode, `webgpu_context.cc:257 Run Uniform variable[5] (head_size) data type mismatch in program "SplitPackedQKVWithRotaryEmbeddingAndCopyKV", Expected: u32, Actual: i32`. No issue in release mode. Convert i32 to u32 to avoid this issue. --- onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc | 4 ++-- onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc index 3f5b28f92f55d..2a67dfdb07912 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc @@ -571,8 +571,8 @@ Status RunSplitPackedQKVWithRotaryEmbeddingAndCopyKV(onnxruntime::webgpu::Comput {static_cast(params.kv_hidden_size_ / components)}, {static_cast(params.num_heads_)}, {static_cast(params.kv_num_heads_)}, - {head_size_vec}, - {half_rotary_embedding_dim_vec}, + {static_cast(head_size_vec)}, + {static_cast(half_rotary_embedding_dim_vec)}, {present_sequence_length}, {tile_size}, {static_cast(dispatch_size)}, diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc index 05717fd2fe686..416a895e61745 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc @@ -128,8 +128,8 @@ Status RunSplitPackedQKVWithRotaryEmbedding(onnxruntime::webgpu::ComputeContext& {static_cast(params.kv_hidden_size_ / components)}, {static_cast(params.num_heads_)}, {static_cast(params.kv_num_heads_)}, - {head_size_vec}, - {half_rotary_embedding_dim_vec}, + {static_cast(head_size_vec)}, + {static_cast(half_rotary_embedding_dim_vec)}, {static_cast(dispatch_size)}, }) .SetDispatchGroupSize((dispatch_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE); From 5c245bc9d99881b7d8860d40d7599babc265a0c0 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Mon, 1 Dec 2025 11:52:10 -0500 Subject: [PATCH 13/20] [webgpu] Fix BatchNormalization ShapeInferenceError for 2D inputs (#26659) ### Description Test model (happens with any 2D inputs): [2191__visual_projection_visual_projection.1_BatchNormalization.onnx.zip](https://github.com/user-attachments/files/23758390/2191__visual_projection_visual_projection.1_BatchNormalization.onnx.zip) Command: ``` python -c "import onnxruntime as ort; ort.InferenceSession('2191__visual_projection_visual_projection.1_BatchNormalization.onnx', providers=['WebGpuExecutionProvider'])" ``` Before (failure): ``` Op (BatchNormalization) [ShapeInferenceError] Tensor must have at least 3 dimensions to convert between channels first and channels last. ``` After (success): ``` (nothing, meaning success) ``` ### Motivation and Context This fixes BatchNormalization on WebGPU, matching CPU version. cc @guschmue --- .../core/graph/contrib_ops/nhwc_inference_context.h | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/graph/contrib_ops/nhwc_inference_context.h b/onnxruntime/core/graph/contrib_ops/nhwc_inference_context.h index bc52a45adfd43..94ef87fb069af 100644 --- a/onnxruntime/core/graph/contrib_ops/nhwc_inference_context.h +++ b/onnxruntime/core/graph/contrib_ops/nhwc_inference_context.h @@ -83,7 +83,8 @@ class NhwcInferenceContext : public ONNX_NAMESPACE::InferenceContext { const int rank = nchw_shape.dim_size(); // N and C dims are required. Some operators like AveragePool allow 1D input if (rank < 3) { - fail_shape_inference("Output tensor must have at least 3 dimensions"); + *nhwc_tp.mutable_tensor_type()->mutable_shape() = nchw_shape; + return; } // Convert output shape from N, C, H {, W, ...} to N, H {, W, ...}, C. @@ -105,8 +106,8 @@ class NhwcInferenceContext : public ONNX_NAMESPACE::InferenceContext { const int rank = nhwc_shape.dim_size(); // N and C dims are required. Some operators like AveragePool allow 1D input. if (rank < 3) { - fail_shape_inference( - "Tensor must have at least 3 dimensions to convert between channels first and channels last."); + *nchw_tp.mutable_tensor_type()->mutable_shape() = nhwc_shape; + return; } // Convert input shape from {N, H, W, ..., C} to {N, C, H, W, ...}. From a19954cd9a0d666dcb438e364ce1758f3548369c Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Mon, 1 Dec 2025 09:04:20 -0800 Subject: [PATCH 14/20] Clear cuda error on unsupported CudaMemPool test (#26629) ### Description CudaMemPool test checks if it is supported in a given environment. We need to clear the error not to affect subsequent tests. ### Motivation and Context Potential test failure. --- .../providers/cuda/cuda_mempool_arena_test.cc | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/onnxruntime/test/providers/cuda/cuda_mempool_arena_test.cc b/onnxruntime/test/providers/cuda/cuda_mempool_arena_test.cc index 70c7a5b2bcdcb..5deef01cd783e 100644 --- a/onnxruntime/test/providers/cuda/cuda_mempool_arena_test.cc +++ b/onnxruntime/test/providers/cuda/cuda_mempool_arena_test.cc @@ -22,10 +22,17 @@ namespace test { // --------- Helpers --------- +// cuda errors are sticky and may affect subsequent API calls. +// we want to clear the error if when supported check fails. +void ClearCudaError() { + ORT_IGNORE_RETURN_VALUE(::cudaGetLastError()); +} + static bool IsCudaMemPoolSupported() { int ort_cuda_rt_version = 0; cudaError_t cuda_status = cudaRuntimeGetVersion(&ort_cuda_rt_version); if (cuda_status != cudaSuccess) { + ClearCudaError(); return false; } @@ -36,6 +43,7 @@ static bool IsCudaMemPoolSupported() { int ort_cuda_driver_version = 0; cuda_status = cudaDriverGetVersion(&ort_cuda_driver_version); if (cuda_status != cudaSuccess) { + ClearCudaError(); return false; } @@ -65,9 +73,10 @@ static bool IsCudaMemPoolSupported() { cudaMemPool_t pool; auto cuda_error = cudaMemPoolCreate(&pool, &props); if (cuda_error != cudaSuccess) { + ClearCudaError(); return false; } - cuda_error = cudaMemPoolDestroy(pool); + ORT_IGNORE_RETURN_VALUE(cudaMemPoolDestroy(pool)); return true; } @@ -80,7 +89,9 @@ static ::cudaStream_t NewCudaStream() { } static void DestroyCudaStream(::cudaStream_t s) { - if (s) (void)::cudaStreamDestroy(s); + if (s) { + EXPECT_EQ(cudaSuccess, ::cudaStreamDestroy(s)); + } } static void TouchDevice(void* p, size_t bytes, ::cudaStream_t s, unsigned char value = 0xAB) { From 55a38c598f5199f8482c11485e1277799eab3117 Mon Sep 17 00:00:00 2001 From: chunghow-qti Date: Tue, 2 Dec 2025 07:12:17 +0800 Subject: [PATCH 15/20] [QNN-EP] Include detailed error message in the returned status (#26546) ### Description The original error message only shows: "Failed to setup QNN input tensors for graph: " This change adds more detailed error information by logging the failure reason from [SetupTensors](https://github.com/microsoft/onnxruntime/blob/ea55c160a36d658eae61a4c7aeda6cb55dd54dec/onnxruntime/core/providers/qnn/builder/qnn_model.cc#L386), making it easier to debug issues. ### Motivation and Context User requires detailed error logging for the ORT online context binary generation. --- onnxruntime/core/providers/qnn/builder/qnn_model.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.cc b/onnxruntime/core/providers/qnn/builder/qnn_model.cc index 85901ab6fdfec..8973a4efa8ba1 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_model.cc @@ -222,14 +222,14 @@ Status QnnModel::SetupQnnInputOutput(const logging::Logger& logger) { auto result = SetupTensors(qnn_input_infos_, graph_info_->InputTensors()); if (Status::OK() != result) { - const std::string message = "Failed to setup QNN input tensors for graph: " + graph_info_->Name(); + const std::string message = "Failed to setup QNN input tensors for graph: " + graph_info_->Name() + ". " + result.ErrorMessage(); LOGS(logger, ERROR) << message; return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, message); } result = SetupTensors(qnn_output_infos_, graph_info_->OutputTensors(), false); if (Status::OK() != result) { - const std::string message = "Failed to setup QNN output tensors for graph: " + graph_info_->Name(); + const std::string message = "Failed to setup QNN output tensors for graph: " + graph_info_->Name() + ". " + result.ErrorMessage(); LOGS(logger, ERROR) << message; return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, message); } From ee77417841dad90627fae49fd5d0aae6756d2ad9 Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Mon, 1 Dec 2025 17:17:33 -0800 Subject: [PATCH 16/20] add support for int32_t in webgpu / slice (#26693) fix for https://github.com/microsoft/onnxruntime/issues/26690 --- .../core/providers/webgpu/tensor/slice.cc | 22 +++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/tensor/slice.cc b/onnxruntime/core/providers/webgpu/tensor/slice.cc index 7e8b434431781..5f59fecc425e2 100644 --- a/onnxruntime/core/providers/webgpu/tensor/slice.cc +++ b/onnxruntime/core/providers/webgpu/tensor/slice.cc @@ -92,14 +92,28 @@ Status SliceProgram::GenerateShaderCode(ShaderHelper& shader) const { return Status::OK(); } +static std::vector getInt64Input(const Tensor* tensor) { + if (tensor->IsDataType()) { + return std::vector(tensor->DataAsSpan().begin(), tensor->DataAsSpan().end()); + } + ORT_ENFORCE(tensor->IsDataType(), "Expected tensor of type int32 or int64"); + std::vector result; + auto span = tensor->DataAsSpan(); + result.reserve(span.size()); + for (auto v : span) { + result.push_back(static_cast(v)); + } + return result; +} + Status Slice::ComputeInternal(ComputeContext& context) const { // READ INPUTS const Tensor* input_tensor = context.Input(0); const TensorShape& input_shape = input_tensor->Shape(); auto input_rank = input_shape.NumDimensions(); - auto starts_raw = attr_starts_.empty() ? context.Input(1)->DataAsSpan() : gsl::make_span(attr_starts_); - auto ends_raw = attr_ends_.empty() ? context.Input(2)->DataAsSpan() : gsl::make_span(attr_ends_); + auto starts_raw = attr_starts_.empty() ? getInt64Input(context.Input(1)) : attr_starts_; + auto ends_raw = attr_ends_.empty() ? getInt64Input(context.Input(2)) : attr_ends_; ORT_ENFORCE(starts_raw.size() == ends_raw.size(), "starts and ends must have the same size"); @@ -126,7 +140,7 @@ Status Slice::ComputeInternal(ComputeContext& context) const { axes_default.push_back(i); } } - auto axes_raw = attr_axes_.empty() ? (axes_tensor == nullptr ? gsl::make_span(axes_default) : axes_tensor->DataAsSpan()) : gsl::make_span(attr_axes_); + auto axes_raw = attr_axes_.empty() ? (axes_tensor == nullptr ? axes_default : getInt64Input(axes_tensor)) : attr_axes_; std::vector steps_default; if (steps_tensor == nullptr) { @@ -135,7 +149,7 @@ Status Slice::ComputeInternal(ComputeContext& context) const { steps_default.push_back(1); } } - auto steps_raw = steps_tensor == nullptr ? gsl::make_span(steps_default) : steps_tensor->DataAsSpan(); + auto steps_raw = steps_tensor == nullptr ? steps_default : getInt64Input(steps_tensor); // get final axes std::vector axes, axes_fixed; From 458e1bb380f211c4a2ac313c40eb680a7be66fdd Mon Sep 17 00:00:00 2001 From: Jiawei Shao Date: Tue, 2 Dec 2025 09:18:36 +0800 Subject: [PATCH 17/20] [webgpu] Remove `global_id` and `workgroup_id` in gemm_utils.cc (#26662) ### Description This patch replaces `global_id` and `workgroup_id` with `logical_global_id` and `logical_workgroup_id` which are computed from `workgroup_idx` and the dispatch workgroup sizes set in `ProgramBase::SetDispatchGroupSize()`. ### Motivation and Context We shouldn't use `global_id` or `workgroup_id` directly because the dispatch workgroup sizes may be normalized in `ProgramManager::NormalizeDispatchGroupSize()`. --- .../core/providers/webgpu/math/gemm_packed.cc | 15 +++--- .../core/providers/webgpu/math/gemm_packed.h | 5 +- .../core/providers/webgpu/math/gemm_utils.cc | 46 +++++++++++++------ .../core/providers/webgpu/math/matmul.cc | 4 +- .../providers/webgpu/math/matmul_packed.h | 5 +- .../core/providers/webgpu/nn/conv2d_mm.cc | 5 +- .../core/providers/webgpu/nn/conv2d_mm.h | 5 +- 7 files changed, 58 insertions(+), 27 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/math/gemm_packed.cc b/onnxruntime/core/providers/webgpu/math/gemm_packed.cc index 6aefa90a59285..c26b58a7af1f4 100644 --- a/onnxruntime/core/providers/webgpu/math/gemm_packed.cc +++ b/onnxruntime/core/providers/webgpu/math/gemm_packed.cc @@ -93,18 +93,21 @@ Status ApplyGemmPacked(const Tensor* a, } const uint32_t TILE_SIZE = 32; - const uint32_t num_tile_n = (N + TILE_SIZE - 1) / TILE_SIZE; - const uint32_t num_tile_m = (M + TILE_SIZE - 1) / TILE_SIZE; + const uint32_t dispatch_x = (N + TILE_SIZE - 1) / TILE_SIZE; + const uint32_t dispatch_y = (M + TILE_SIZE - 1) / TILE_SIZE; program.CacheHint(alpha, transA, transB, c_is_scalar) .AddOutputs({{y, ProgramTensorMetadataDependency::TypeAndRank, output_components}}) - .SetDispatchGroupSize(num_tile_n, num_tile_m, 1) + .SetDispatchGroupSize(dispatch_x, dispatch_y, 1u) .SetWorkgroupSize(GemmProgram::MATMUL_PACKED_WORKGROUP_SIZE_X, GemmProgram::MATMUL_PACKED_WORKGROUP_SIZE_Y, GemmProgram::MATMUL_PACKED_WORKGROUP_SIZE_Z) .AddUniformVariables({{alpha}, {beta}, - {M}, /* dim_a_outer */ - {N}, /* dim_b_outer */ - {K}} /*dim_inner */ + {M}, /* dim_a_outer */ + {N}, /* dim_b_outer */ + {K}, /*dim_inner */ + {dispatch_x}, /* logical_dispatch_x */ + {dispatch_y}, /* logical_dispatch_y */ + {1u}} /* logical_dispatch_z */ ); return context.RunProgram(program); diff --git a/onnxruntime/core/providers/webgpu/math/gemm_packed.h b/onnxruntime/core/providers/webgpu/math/gemm_packed.h index dce5164693aa8..cb89ccefba313 100644 --- a/onnxruntime/core/providers/webgpu/math/gemm_packed.h +++ b/onnxruntime/core/providers/webgpu/math/gemm_packed.h @@ -32,7 +32,10 @@ class GemmProgram final : public Program { {"beta", ProgramUniformVariableDataType::Float32}, {"dim_a_outer", ProgramUniformVariableDataType::Uint32}, {"dim_b_outer", ProgramUniformVariableDataType::Uint32}, - {"dim_inner", ProgramUniformVariableDataType::Uint32}); + {"dim_inner", ProgramUniformVariableDataType::Uint32}, + {"logical_dispatch_x", ProgramUniformVariableDataType::Uint32}, + {"logical_dispatch_y", ProgramUniformVariableDataType::Uint32}, + {"logical_dispatch_z", ProgramUniformVariableDataType::Uint32}); constexpr static uint32_t MATMUL_PACKED_WORKGROUP_SIZE_X = 8; constexpr static uint32_t MATMUL_PACKED_WORKGROUP_SIZE_Y = 8; diff --git a/onnxruntime/core/providers/webgpu/math/gemm_utils.cc b/onnxruntime/core/providers/webgpu/math/gemm_utils.cc index 7cbc7f6a4a821..89718149cea88 100644 --- a/onnxruntime/core/providers/webgpu/math/gemm_utils.cc +++ b/onnxruntime/core/providers/webgpu/math/gemm_utils.cc @@ -117,6 +117,20 @@ void HandleMatMulWithSplitK( } } +// Compute `logical_workgroup_id` and `logical_global_id` because the dispatch workgroup size in +// `ProgramBase.SetDispatchGroupSize()` may be normalized in +// `ProgramManager::NormalizeDispatchGroupSize()`. In the shader we should always use +// `logical_workgroup_id` and `logical_global_id` instead of `workgroup_id` and `global_id`. +void InitializeLogicalWorkgroupIDAndGlobalID(ShaderHelper& shader) { + shader.MainFunctionBody() + << " let logical_workgroup_id_z = workgroup_idx / (uniforms.logical_dispatch_x * uniforms.logical_dispatch_y);\n" + << " let logical_workgroup_id_y = (workgroup_idx % (uniforms.logical_dispatch_x * uniforms.logical_dispatch_y)) / uniforms.logical_dispatch_x;\n" + << " let logical_workgroup_id_x = (workgroup_idx % (uniforms.logical_dispatch_x * uniforms.logical_dispatch_y)) % uniforms.logical_dispatch_x;\n" + << " let logical_workgroup_id = vec3u(logical_workgroup_id_x, logical_workgroup_id_y, logical_workgroup_id_z);\n" + << " const workgroupSize = vec3u(workgroup_size_x, workgroup_size_y, workgroup_size_z);\n" + << " let logical_global_id = logical_workgroup_id * workgroupSize + local_id;\n"; +} + } // namespace void MatMulReadFnSource(ShaderHelper& shader, @@ -274,20 +288,22 @@ Status MakeMatMulPackedVec4Source(ShaderHelper& shader, << "const innerElementSize = " << inner_elements_size << ";\n" << "const tileInner = " << tile_inner << ";\n"; + InitializeLogicalWorkgroupIDAndGlobalID(shader); + shader.MainFunctionBody() << " let localRow = i32(local_id.y);\n" << " let tileRow = localRow * rowPerThread;\n" << " let tileCol = i32(local_id.x);\n" - << " let globalRow = i32(global_id.y) * rowPerThread;\n" - << " let globalCol = i32(global_id.x);\n" - << " let globalRowStart = i32(workgroup_id.y) * " << tile_a_outer << ";\n" - << " let globalColStart = i32(workgroup_id.x) * " << tile_b_outer << ";\n" + << " let globalRow = i32(logical_global_id.y) * rowPerThread;\n" + << " let globalCol = i32(logical_global_id.x);\n" + << " let globalRowStart = i32(logical_workgroup_id.y) * " << tile_a_outer << ";\n" + << " let globalColStart = i32(logical_workgroup_id.x) * " << tile_b_outer << ";\n" << " var acc: array, rowPerThread>;\n"; if (split_k) { // With Split-K, the original "workgroup" (with dispatch_z == 1 in API side) is split into // multiple ones, and in the current workgroup we only compute `kSplitK` elements starting from - // `kSplitK * i32(global_id.z)`. + // `kSplitK * i32(logical_global_id.z)`. // // For example: considering computing Y = (X * W + B) in one workgroup. // Let kSplitK = 2, B = [d1, d2] @@ -305,15 +321,15 @@ Status MakeMatMulPackedVec4Source(ShaderHelper& shader, // Workgroup1: compute (A1 * A2) Workgroup2: compute (B1 * B2) // Workgroup3: compute (C1 * C2) // In each workgroup: - // - `num_tiles` is computed with `kSplitK`, and `kStart` is computed with `global_id.z` + // - `num_tiles` is computed with `kSplitK`, and `kStart` is computed with `logical_global_id.z` // - When the computation in each workgroup is completed, add the result to Y with several // atomic built-in functions in `HandleMatMulWithSplitK()`. shader.MainFunctionBody() << "const kSplitK = " << split_dim_inner << ";\n" << " let num_tiles = (kSplitK - 1) / tileInner + 1;\n" - << " var kStart = kSplitK * i32(global_id.z);\n" + << " var kStart = kSplitK * i32(logical_global_id.z);\n" - // When Split-K is used, `batch` should always be 0 and `global_id.z` is used to indicate + // When Split-K is used, `batch` should always be 0 and `logical_global_id.z` is used to indicate // the index of split-k instead of batch. << " let batch = 0;\n" << " let batchIndices = 0u;\n"; @@ -321,7 +337,7 @@ Status MakeMatMulPackedVec4Source(ShaderHelper& shader, shader.MainFunctionBody() << " let num_tiles = (uniforms.dim_inner - 1) / tileInner + 1;\n" << " var kStart = 0;\n" - << " let batch = i32(global_id.z);\n" + << " let batch = i32(logical_global_id.z);\n" << (nullptr != batch_dims ? " let batchIndices = " + batch_dims->OffsetToIndices("u32(batch)") + ";\n" : ""); } @@ -498,7 +514,9 @@ Status MakeMatMulPackedSource(ShaderHelper& shader, << "const colPerThread = " << elements_per_thread_x << ";\n" << "const tileInner = " << tile_inner << ";\n"; - shader.MainFunctionBody() << " let batch = i32(global_id.z);\n" + InitializeLogicalWorkgroupIDAndGlobalID(shader); + + shader.MainFunctionBody() << " let batch = i32(logical_global_id.z);\n" << (nullptr != batch_dims ? " let batchIndices = " + batch_dims->OffsetToIndices("u32(batch)") + ";\n" : "") << " let num_tiles = (uniforms.dim_inner - 1) / tileInner + 1;\n" << " var kStart = 0;\n" @@ -507,10 +525,10 @@ Status MakeMatMulPackedSource(ShaderHelper& shader, shader.MainFunctionBody() << "let tileRow = i32(local_id.y) * rowPerThread;\n" << "let tileCol = i32(local_id.x) * colPerThread;\n" - << "let globalRow = i32(global_id.y) * rowPerThread;\n" - << "let globalCol = i32(global_id.x) * colPerThread;\n" - << "let globalRowStart = i32(workgroup_id.y) * " << tile_a_outer << ";\n" - << "let globalColStart = i32(workgroup_id.x) * " << tile_b_outer << ";\n" + << "let globalRow = i32(logical_global_id.y) * rowPerThread;\n" + << "let globalCol = i32(logical_global_id.x) * colPerThread;\n" + << "let globalRowStart = i32(logical_workgroup_id.y) * " << tile_a_outer << ";\n" + << "let globalColStart = i32(logical_workgroup_id.x) * " << tile_b_outer << ";\n" << "let tileRowA = i32(local_id.y) * " << row_per_thread_a << ";\n" << "let tileColA = i32(local_id.x) * " << col_per_thread_a << ";\n" << "let tileRowB = i32(local_id.y) * " << row_per_thread_b << ";\n"; diff --git a/onnxruntime/core/providers/webgpu/math/matmul.cc b/onnxruntime/core/providers/webgpu/math/matmul.cc index 55c2c5773cc1f..72dd235eb820a 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul.cc +++ b/onnxruntime/core/providers/webgpu/math/matmul.cc @@ -256,8 +256,6 @@ Status ComputeMatMul(ComputeContext* context, // With Split-K, `dim_inner` will be split into multiple parts and `dispatch_z` will be the // number of splits along `dim_inner`. - // TODO: avoid using `global_id.xxx` or `workgroup_id.xxx` in `MatMulProgram` when we normalize - // the dispatch size with `ProgramManager::NormalizeDispatchGroupSize()` for `MatMulProgram`. split_dim_inner = split_k_config.GetSplitDimInner(); dispatch_z = (dim_inner + split_dim_inner - 1) / split_dim_inner; @@ -271,7 +269,7 @@ Status ComputeMatMul(ComputeContext* context, .CacheHint(activation.ToString(), absl::StrJoin(elements_per_thread, "-"), std::to_string(is_vec4), components, is_channels_last, split_dim_inner) .AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, a_shape_temp, components}, {b, ProgramTensorMetadataDependency::TypeAndRank, b_shape_temp, components}}) - .AddUniformVariables({{dim_a_outer}, {dim_b_outer}, {dim_inner}}) + .AddUniformVariables({{dim_a_outer}, {dim_b_outer}, {dim_inner}, {dispatch_x}, {dispatch_y}, {dispatch_z}}) .AddIndices(outer_dims) .SetDispatchGroupSize(dispatch_x, dispatch_y, dispatch_z) .SetWorkgroupSize(MatMul::MATMUL_PACKED_WORKGROUP_SIZE_X, MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Y, MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Z) diff --git a/onnxruntime/core/providers/webgpu/math/matmul_packed.h b/onnxruntime/core/providers/webgpu/math/matmul_packed.h index 143ba61c99e13..dbd193bc38f58 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul_packed.h +++ b/onnxruntime/core/providers/webgpu/math/matmul_packed.h @@ -24,7 +24,10 @@ class MatMulProgram final : public Program { Status GenerateShaderCode(ShaderHelper& sh) const override; WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"dim_a_outer", ProgramUniformVariableDataType::Uint32}, {"dim_b_outer", ProgramUniformVariableDataType::Uint32}, - {"dim_inner", ProgramUniformVariableDataType::Uint32}); + {"dim_inner", ProgramUniformVariableDataType::Uint32}, + {"logical_dispatch_x", ProgramUniformVariableDataType::Uint32}, + {"logical_dispatch_y", ProgramUniformVariableDataType::Uint32}, + {"logical_dispatch_z", ProgramUniformVariableDataType::Uint32}); bool NeedSplitK() const; diff --git a/onnxruntime/core/providers/webgpu/nn/conv2d_mm.cc b/onnxruntime/core/providers/webgpu/nn/conv2d_mm.cc index 2d5424c52a3f2..c66f2cbd582d9 100644 --- a/onnxruntime/core/providers/webgpu/nn/conv2d_mm.cc +++ b/onnxruntime/core/providers/webgpu/nn/conv2d_mm.cc @@ -226,7 +226,10 @@ Conv2dMMProgram CreateConv2dMMProgram(const Activation& activation, const std::v {static_cast(dim_inner)}, {pads}, {strides}, - {dilations}}); + {dilations}, + {dispatch[0]}, + {dispatch[1]}, + {dispatch[2]}}); return program; } diff --git a/onnxruntime/core/providers/webgpu/nn/conv2d_mm.h b/onnxruntime/core/providers/webgpu/nn/conv2d_mm.h index d7cc08aae26f3..e161bffb0c503 100644 --- a/onnxruntime/core/providers/webgpu/nn/conv2d_mm.h +++ b/onnxruntime/core/providers/webgpu/nn/conv2d_mm.h @@ -38,7 +38,10 @@ class Conv2dMMProgram final : public Program { {"dim_inner", ProgramUniformVariableDataType::Uint32}, {"pads", ProgramUniformVariableDataType::Uint32}, {"strides", ProgramUniformVariableDataType::Uint32}, - {"dilations", ProgramUniformVariableDataType::Uint32}); + {"dilations", ProgramUniformVariableDataType::Uint32}, + {"logical_dispatch_x", ProgramUniformVariableDataType::Uint32}, + {"logical_dispatch_y", ProgramUniformVariableDataType::Uint32}, + {"logical_dispatch_z", ProgramUniformVariableDataType::Uint32}); private: const Activation& activation_; From c156e933b34876c959a4b4c611d2c7dd8e71cafc Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Tue, 2 Dec 2025 13:44:44 -0500 Subject: [PATCH 18/20] [webgpu] Correct definition of large numbers, fixes softmax(max_negative_number) in float32 (#26670) ### Description The correct definition of the most negative number is `-3.40282346638528e+38`, according to IEEE 754, but it is being incorrectly registered inline as a truncated version `-3.402823e+38f`. ```py >>> import numpy as np >>> np.finfo(np.float32).min np.float32(-3.4028235e+38) >>> np.finfo(np.float32).min.item() -3.4028234663852886e+38 ``` For this reason, values less than this threshold were handled incorrectly. While this may seem like a small/irrelevant detail, it's essential in attention masking, where we do in fact use this value, leading to large numerical errors down the line. Reproduction: ```py from onnx import helper, TensorProto import onnxruntime as ort import numpy as np # 1. Create the ONNX model # Define input and output input_shape = [1, 2] input_info = helper.make_tensor_value_info('X', TensorProto.FLOAT, input_shape) output_info = helper.make_tensor_value_info('Y', TensorProto.FLOAT, input_shape) # Create the Softmax node # Softmax takes one input: X softmax_node = helper.make_node( 'Softmax', inputs=['X'], outputs=['Y'], name='SoftmaxNode', axis=-1 # Default axis is -1, usually applied to the last dimension ) # Create the graph graph_def = helper.make_graph( [softmax_node], 'test-model', [input_info], [output_info] ) # Create the model model_def = helper.make_model(graph_def, producer_name='onnx-example') opset = model_def.opset_import[0] opset.version = 13 # Ensure opset version supports the operations # 2. Convert model to string (bytes) model_str = model_def.SerializeToString() # 3. Prepare input data np.random.seed(0) input_data = np.array( [[-3.40282346638528e+38, -3.40282346638528e+38]] # [[-3.4028234663852886e+38, -3.4028234663852886e+38]] ).astype(np.float32) print(input_data.tolist()) # 4. Run on CPUExecutionProvider sess_cpu = ort.InferenceSession(model_str, providers=['CPUExecutionProvider']) res_cpu = sess_cpu.run(['Y'], {'X': input_data})[0] print("CPU Result:", res_cpu) # 5. Run on WebGpuExecutionProvider sess_webgpu = ort.InferenceSession(model_str, providers=['WebGpuExecutionProvider']) res_webgpu = sess_webgpu.run(['Y'], {'X': input_data})[0] print("WebGPU Result:", res_webgpu) # Compare results diff = np.abs(res_cpu - res_webgpu) max_diff = diff.max().item() print(diff) print(f"Max diff: {max_diff}") assert max_diff < 1e-5, f"Results do not match within tolerance! Max diff: {max_diff}" print("Results match!") ``` Before: ``` [[-3.4028234663852886e+38, -3.4028234663852886e+38]] CPU Result: [[0.5 0.5]] WebGPU Result: [[0. 0.]] [[0.5 0.5]] Max diff: 0.5 AssertionError: Results do not match within tolerance! Max diff: 0.5 ``` After: ``` [[-3.4028234663852886e+38, -3.4028234663852886e+38]] CPU Result: [[0.5 0.5]] WebGPU Result: [[0.5 0.5]] [[0. 0.]] Max diff: 0.0 Results match! ``` cc @guschmue --- js/web/lib/wasm/jsep/webgpu/ops/attention.ts | 4 ++-- js/web/lib/wasm/jsep/webgpu/ops/softmax.ts | 2 +- onnxruntime/contrib_ops/webgpu/bert/attention.cc | 6 +++--- .../contrib_ops/webgpu/bert/flash_attention.wgsl.template | 2 +- .../webgpu/bert/flash_attention_decode_qkt.wgsl.template | 2 +- .../bert/flash_attention_decode_split_vx.wgsl.template | 2 +- onnxruntime/contrib_ops/webgpu/moe/gate.wgsl.template | 2 +- onnxruntime/core/providers/js/operators/unary.cc | 2 +- .../core/providers/vsinpu/builders/impl/clip_op_builder.cc | 4 ++-- onnxruntime/core/providers/webgpu/math/softmax.cc | 2 +- 10 files changed, 14 insertions(+), 14 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts index 6a8dffb73fa08..f0f7527f665b9 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts @@ -360,7 +360,7 @@ const createInPlaceSoftmaxProgramInfo = ( let local_offset = local_idx * uniforms.elements_per_thread; let offset = (global_idx / ${WG}) * uniforms.total_sequence_length + local_offset; let seq_causal_length = ${seqLens ? 'u32(past_sequence_length + workgroup_id.y + 1)' : 'total_sequence_length'}; - var thread_max_vector = ${f32Type}(-3.402823e+38f); + var thread_max_vector = ${f32Type}(-3.4028234663852886e+38f); for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) { thread_max_vector = max(${f32Type}(x[offset + i]), thread_max_vector); } @@ -378,7 +378,7 @@ const createInPlaceSoftmaxProgramInfo = ( })()}; workgroupBarrier(); - var max_value = f32(-3.402823e+38f); + var max_value = f32(-3.4028234663852886e+38f); for (var i = 0u; i < ${WG}; i++) { max_value = max(thread_max[i], max_value); } diff --git a/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts b/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts index 2056416873df5..f6882280e91df 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts @@ -81,7 +81,7 @@ const createSoftmaxProgramInfo = (context: ComputeContext, attributes: SoftmaxAt // 6.2.4 in wgsl spec const threadMaxDecl = tensorTypeToWsglStorageType(transposedInput.dataType) === 'f32' - ? `var threadMax = ${valueType}(-3.402823e+38f);` + ? `var threadMax = ${valueType}(-3.4028234663852886e+38f);` : `var threadMax = ${valueType}(-65504.0h);`; const getShaderSource = (shaderHelper: ShaderHelper) => ` var rowMaxShared : ${valueType}; diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc index a5ab63d74df24..130dd0c25a880 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc @@ -165,7 +165,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const { << " let query_pos = m + local_id.y + past_sequence_length;\n" << " let key_pos = n + local_id.x;\n" << " if (key_pos > query_pos) {\n" - << " sum = -3.40282e+38; // Set to very negative value for masking\n" + << " sum = -3.4028234663852886e+38; // Set to very negative value for masking\n" << " }\n"; } @@ -272,7 +272,7 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const { << "let effective_seq_length = seq_causal_length;\n"; } shader.MainFunctionBody() - << "var thread_max_vector = f32_val_t(-3.402823e+38f);\n" + << "var thread_max_vector = f32_val_t(-3.4028234663852886e+38f);\n" << "for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < effective_seq_length; i++) {\n" << " let actual_pos = local_offset + i + start_offset;\n" << " if (!should_apply_local_window || actual_pos < seq_causal_length) {\n" @@ -289,7 +289,7 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const { } else if (use_smooth_softmax_) { shader.MainFunctionBody() << "var max_value: f32 = 0.0;\n"; } else { - shader.MainFunctionBody() << "var max_value = f32(-3.402823e+38f);\n"; + shader.MainFunctionBody() << "var max_value = f32(-3.4028234663852886e+38f);\n"; } shader.MainFunctionBody() << "for (var i = 0u; i < " << work_group_size_ << "; i++) {\n" diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.wgsl.template index a5922ec9512fd..ff8e4ecc08bab 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.wgsl.template @@ -26,7 +26,7 @@ fn get_total_sequence_length() -> u32 { #if is_fp16 const min_value = q_element_t(-65504.0); #else -const min_value = q_element_t(-3.402823e+38f); +const min_value = q_element_t(-3.4028234663852886e+38f); #endif // For max performance max_k_step should be the same as sg_size, however we might run out of registers diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkt.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkt.wgsl.template index c6f768beffa0f..ac9a157492007 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkt.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkt.wgsl.template @@ -93,7 +93,7 @@ $MAIN { if (local_idx == 0u) { // Calculate the max and sum in current split. - var l_max = f32(-3.402823e+38f); + var l_max = f32(-3.4028234663852886e+38f); var l_sum = f32(0); for (var i = 0u; i < tile_size && (total_seq_offset + i) < total_sequence_length; i++) { l_max = max(l_max, f32(tile_qk[i])); diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_split_vx.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_split_vx.wgsl.template index 37cf7e8f11b1f..a113e96130985 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_split_vx.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_split_vx.wgsl.template @@ -54,7 +54,7 @@ $MAIN { // Calculate the global max and sum in qk. if (head_idx < uniforms.num_heads) { - var g_max = f32(-3.402823e+38f); + var g_max = f32(-3.4028234663852886e+38f); var g_sum = f32(0); for (var i = 0u; i < num_total_seq_length_tile; i++) { diff --git a/onnxruntime/contrib_ops/webgpu/moe/gate.wgsl.template b/onnxruntime/contrib_ops/webgpu/moe/gate.wgsl.template index 1214777009a8d..6e0d4c7299793 100644 --- a/onnxruntime/contrib_ops/webgpu/moe/gate.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/moe/gate.wgsl.template @@ -18,7 +18,7 @@ const K: u32 = k; #if is_fp16 const MAX_FLOAT: f16 = 65504.0; #else -const MAX_FLOAT: f32 = 3.402823466e+38; +const MAX_FLOAT: f32 = 3.4028234663852886e+38; #endif var shared_vals: array; diff --git a/onnxruntime/core/providers/js/operators/unary.cc b/onnxruntime/core/providers/js/operators/unary.cc index ef977161bcc37..26144e6ba3995 100644 --- a/onnxruntime/core/providers/js/operators/unary.cc +++ b/onnxruntime/core/providers/js/operators/unary.cc @@ -126,7 +126,7 @@ JSEP_ELEMENTWISE_TYPED_KERNEL(Not, 1, bool, Not) // activation -JSEP_CLASS_IMPL_ATTRIBUTE_FLOAT_2_DEFAULT(ClipV10, Clip, min, 3.402823e+38f, max, -3.402823e+38f) +JSEP_CLASS_IMPL_ATTRIBUTE_FLOAT_2_DEFAULT(ClipV10, Clip, min, 3.4028234663852886e+38f, max, -3.4028234663852886e+38f) JSEP_ELEMENTWISE_VERSIONED_KERNEL(Clip, 6, 10, ClipV10) JSEP_KERNEL_IMPL(Clip, Clip) ONNX_OPERATOR_VERSIONED_KERNEL_EX(Clip, kOnnxDomain, 11, 11, kJsExecutionProvider, diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/clip_op_builder.cc b/onnxruntime/core/providers/vsinpu/builders/impl/clip_op_builder.cc index 85096d0e262d7..9948069c6779b 100644 --- a/onnxruntime/core/providers/vsinpu/builders/impl/clip_op_builder.cc +++ b/onnxruntime/core/providers/vsinpu/builders/impl/clip_op_builder.cc @@ -78,8 +78,8 @@ bool ClipOpBuilder::HandleBuildOp(vsi::npu::GraphEP* graph_ep, LOGS_DEFAULT(INFO) << "Creating Clip Op."; if (node_unit.SinceVersion() <= 6) { NodeAttrHelper helper(node_unit.GetNode()); - auto min = helper.Get("min", -3.402e+38f); - auto max = helper.Get("max", 3.402e+38f); + auto min = helper.Get("min", -3.4028234663852886e+38f); + auto max = helper.Get("max", 3.4028234663852886e+38f); auto op = graph_ep->GetGraph()->CreateOperation(min, max); (*op).BindInputs(inputs).BindOutputs(outputs); graph_ep->GetOps().push_back(std::move(op)); diff --git a/onnxruntime/core/providers/webgpu/math/softmax.cc b/onnxruntime/core/providers/webgpu/math/softmax.cc index 2f34aa21c8309..bf3bb53341418 100644 --- a/onnxruntime/core/providers/webgpu/math/softmax.cc +++ b/onnxruntime/core/providers/webgpu/math/softmax.cc @@ -64,7 +64,7 @@ Status SoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const { int components = input.NumComponents(); const std::string thread_max_decl = is_fp32_ - ? "var thread_max = x_value_t(-3.402823e+38f);\n" + ? "var thread_max = x_value_t(-3.4028234663852886e+38f);\n" : "var thread_max = x_value_t(-65504.0h);\n"; // Define shared memory for row max and row sum From 2b659e4d1a8a16574b87804c4783e1d36bad7d4d Mon Sep 17 00:00:00 2001 From: Chi Lo <54722500+chilo-ms@users.noreply.github.com> Date: Tue, 2 Dec 2025 11:16:59 -0800 Subject: [PATCH 19/20] [TRT/TRT RTX EP] Fix bug for missing outputs in the returning ComputeCapability/IndexedSubGraph (#26444) ### Description For TRT EP's `GetCapability()`, in some case, the `GetSubGraph()` won't add graph's output to the `ComputeCapability/IndexedSubGraph` returning to ORT. The issue if from following code: ````c++ ... if (node->GetOutputEdgesCount() > node->OutputDefs().size()) { ... // execute here } else { ... if (graph_output_names.find(output->Name()) != graph_output_names.end()) { graph_outputs_to_add[output] = output_order; // missing this } } ```` Update TRT RTX EP as well. ### Motivation and Context https://github.com/microsoft/onnxruntime/issues/25373 --- .../nv_tensorrt_rtx/nv_execution_provider.cc | 77 ++++++++++------- .../tensorrt/tensorrt_execution_provider.cc | 77 ++++++++++------- .../nv_tensorrt_rtx/nv_basic_test.cc | 42 ++++++++++ .../providers/tensorrt/tensorrt_basic_test.cc | 49 ++++++++++- .../test/testdata/node_output_not_used.onnx | Bin 0 -> 189 bytes .../test/testdata/node_output_not_used.py | 43 ++++++++++ .../topk_and_multiple_graph_outputs.onnx | Bin 0 -> 393 bytes .../topk_and_multiple_graph_outputs.py | 78 ++++++++++++++++++ 8 files changed, 303 insertions(+), 63 deletions(-) create mode 100644 onnxruntime/test/testdata/node_output_not_used.onnx create mode 100644 onnxruntime/test/testdata/node_output_not_used.py create mode 100644 onnxruntime/test/testdata/topk_and_multiple_graph_outputs.onnx create mode 100644 onnxruntime/test/testdata/topk_and_multiple_graph_outputs.py diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc index e2a8005aba1da..d148c4191d5d7 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc @@ -1407,9 +1407,30 @@ std::unique_ptr NvExecutionProvider::GetSubGraph(SubGraph_t gra } // Find inputs and outputs of the subgraph + std::unique_ptr sub_graph = onnxruntime::IndexedSubGraph::Create(); - std::unordered_map original_inputs, fused_inputs, fused_outputs, fused_outputs_to_add, graph_outputs_to_add; + std::unordered_map original_inputs; + + // These maps store the inputs and outputs of the subgraph. + // Please note that the inputs and outputs of the maps will be dynamically updated during node iteration + // to determine the final inputs and outputs of the subgraph. + std::unordered_map fused_inputs, fused_outputs; + + // This map stores the node's output that will be consumed by another node outside of this subgraph. + // So the node's output should be put into the subgraph's output list. + std::unordered_map fused_outputs_to_add; + + // This map stores the node's output that is original graph's output. + // So the node's output should be put into the subgraph's output list. + std::unordered_map graph_outputs_to_add; + std::unordered_set erased; + + // This is the relative ordering that ensures node's input or output being added to the 'fused_inputs', + // 'fused_outputs', 'fused_outputs_to_add' and 'graph_outputs_to_add' maps is associated with a relative order index. + // Items added earlier receive a smaller order index than items added later. + // When constructing the final sub_graph's input or output lists, entries with smaller + // order indices will appear before those with larger indices. int input_order = 0; int output_order = 0; @@ -1428,7 +1449,7 @@ std::unique_ptr NvExecutionProvider::GetSubGraph(SubGraph_t gra erased.insert(input); } else if (erased.find(input) == erased.end()) { // Only when input is neither in output list nor erased list, add the input to input list - fused_inputs[input] = input_order++; + fused_inputs.insert({input, input_order++}); } } @@ -1443,7 +1464,7 @@ std::unique_ptr NvExecutionProvider::GetSubGraph(SubGraph_t gra erased.insert(input); } else if (erased.find(input) == erased.end()) { // Only when input is neither in output list nor erased list, add the input to input list - fused_inputs[input] = input_order++; + fused_inputs.insert({input, input_order++}); } } @@ -1464,39 +1485,33 @@ std::unique_ptr NvExecutionProvider::GetSubGraph(SubGraph_t gra } else { output = (it->GetNode()).ImplicitInputDefs()[it->GetDstArgIndex() - static_cast(it->GetNode().InputDefs().size())]; } - if (node_set.find(node_idx) != node_set.end()) { - const auto& iter = fused_inputs.find(output); - if (iter != fused_inputs.end()) { - fused_inputs.erase(iter); - erased.insert(output); - } else if (erased.find(output) == erased.end()) { - if (graph_output_names.find(output->Name()) != graph_output_names.end()) { - graph_outputs_to_add[output] = output_order; - } - fused_outputs[output] = output_order++; - } - } else { - fused_outputs_to_add[output] = output_order++; + + if (node_set.find(node_idx) == node_set.end()) { + // This output will be consumed by another node outside of this subgraph. + // So the output should be put into the subgraph's output list. + fused_outputs_to_add.insert({output, output_order++}); } } - } else { - for (const auto& output : node->OutputDefs()) { - const auto& it = fused_inputs.find(output); - if (it != fused_inputs.end()) { - fused_inputs.erase(it); - erased.insert(output); - } - // Only when output is neither in input list nor erased list, and the output is consumed by another node, add the output to output list - else if (erased.find(output) == erased.end()) { - if (graph_output_names.find(output->Name()) != graph_output_names.end()) { - graph_outputs_to_add[output] = output_order; - } + } - if (graph.GetGraph().GetConsumerNodes(output->Name()).size() > 0) { - fused_outputs[output] = output_order++; - } + for (const auto& output : node->OutputDefs()) { + const auto& it = fused_inputs.find(output); + if (it != fused_inputs.end()) { + fused_inputs.erase(it); + erased.insert(output); + } else if (erased.find(output) == erased.end()) { + if (graph.GetGraph().GetConsumerNodes(output->Name()).size() > 0) { + // Only when output is neither in input list nor erased list, + // and the output is consumed by another node, add the output to output list + fused_outputs.insert({output, output_order++}); } } + + if (graph_output_names.find(output->Name()) != graph_output_names.end()) { + // This output is the graph's output. + // So the output should be put into the subgraph's output list. + graph_outputs_to_add.insert({output, output_order++}); + } } } diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index cd0c0e4bffdb5..e5b48da33fbc3 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -2035,9 +2035,30 @@ std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGraph } // Find inputs and outputs of the subgraph + std::unique_ptr sub_graph = onnxruntime::IndexedSubGraph::Create(); - std::unordered_map original_inputs, fused_inputs, fused_outputs, fused_outputs_to_add, graph_outputs_to_add; + std::unordered_map original_inputs; + + // These maps store the inputs and outputs of the subgraph. + // Please note that the inputs and outputs of the maps will be dynamically updated during node iteration + // to determine the final inputs and outputs of the subgraph. + std::unordered_map fused_inputs, fused_outputs; + + // This map stores the node's output that will be consumed by another node outside of this subgraph. + // So the node's output should be put into the subgraph's output list. + std::unordered_map fused_outputs_to_add; + + // This map stores the node's output that is original graph's output. + // So the node's output should be put into the subgraph's output list. + std::unordered_map graph_outputs_to_add; + std::unordered_set erased; + + // This is the relative ordering that ensures node's input or output being added to the 'fused_inputs', + // 'fused_outputs', 'fused_outputs_to_add' and 'graph_outputs_to_add' maps is associated with a relative order index. + // Items added earlier receive a smaller order index than items added later. + // When constructing the final sub_graph's input or output lists, entries with smaller + // order indices will appear before those with larger indices. int input_order = 0; int output_order = 0; @@ -2056,7 +2077,7 @@ std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGraph erased.insert(input); } else if (erased.find(input) == erased.end()) { // Only when input is neither in output list nor erased list, add the input to input list - fused_inputs[input] = input_order++; + fused_inputs.insert({input, input_order++}); } } @@ -2071,7 +2092,7 @@ std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGraph erased.insert(input); } else if (erased.find(input) == erased.end()) { // Only when input is neither in output list nor erased list, add the input to input list - fused_inputs[input] = input_order++; + fused_inputs.insert({input, input_order++}); } } @@ -2092,39 +2113,33 @@ std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGraph } else { output = (it->GetNode()).ImplicitInputDefs()[it->GetDstArgIndex() - static_cast(it->GetNode().InputDefs().size())]; } - if (node_set.find(node_idx) != node_set.end()) { - const auto& iter = fused_inputs.find(output); - if (iter != fused_inputs.end()) { - fused_inputs.erase(iter); - erased.insert(output); - } else if (erased.find(output) == erased.end()) { - if (graph_output_names.find(output->Name()) != graph_output_names.end()) { - graph_outputs_to_add[output] = output_order; - } - fused_outputs[output] = output_order++; - } - } else { - fused_outputs_to_add[output] = output_order++; + + if (node_set.find(node_idx) == node_set.end()) { + // This output will be consumed by another node outside of this subgraph. + // So the output should be put into the subgraph's output list. + fused_outputs_to_add.insert({output, output_order++}); } } - } else { - for (const auto& output : node->OutputDefs()) { - const auto& it = fused_inputs.find(output); - if (it != fused_inputs.end()) { - fused_inputs.erase(it); - erased.insert(output); - } - // Only when output is neither in input list nor erased list, and the output is consumed by another node, add the output to output list - else if (erased.find(output) == erased.end()) { - if (graph_output_names.find(output->Name()) != graph_output_names.end()) { - graph_outputs_to_add[output] = output_order; - } + } - if (graph.GetGraph().GetConsumerNodes(output->Name()).size() > 0) { - fused_outputs[output] = output_order++; - } + for (const auto& output : node->OutputDefs()) { + const auto& it = fused_inputs.find(output); + if (it != fused_inputs.end()) { + fused_inputs.erase(it); + erased.insert(output); + } else if (erased.find(output) == erased.end()) { + if (graph.GetGraph().GetConsumerNodes(output->Name()).size() > 0) { + // Only when output is neither in input list nor erased list, + // and the output is consumed by another node, add the output to output list + fused_outputs.insert({output, output_order++}); } } + + if (graph_output_names.find(output->Name()) != graph_output_names.end()) { + // This output is the graph's output. + // So the output should be put into the subgraph's output list. + graph_outputs_to_add.insert({output, output_order++}); + } } } diff --git a/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc index d8cc56d738175..af9706855ee3c 100644 --- a/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc +++ b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc @@ -203,6 +203,48 @@ TEST_P(TypeTests, IOTypes) { } } +TEST(NvExecutionProviderTest, TestSessionOutputs) { + /* + * Model #1: + * + * "input" ---> TopK --- + * |---> "scores" + * |--- Less ---> "Less_output_0" + * |--- Div ---> "Div_output_0" + * |--- Mod ---> "labels" + */ + { + Ort::SessionOptions session_options; + session_options.AppendExecutionProvider(kNvTensorRTRTXExecutionProvider, {}); + + auto model_path = ORT_TSTR("testdata/topk_and_multiple_graph_outputs.onnx"); + Ort::Session session(*ort_env, model_path, session_options); + + size_t output_count = session.GetOutputCount(); + ASSERT_TRUE(output_count == 4); + } + + /* + * Model #2: + * + * "X" ---> Dropout ---> MatMul ---> "Y" + * ^ | + * | | + * "W" ------ ----> Can't be graph's output + * + */ + { + Ort::SessionOptions session_options; + session_options.AppendExecutionProvider(kNvTensorRTRTXExecutionProvider, {}); + + auto model_path = ORT_TSTR("testdata/node_output_not_used.onnx"); + Ort::Session session(*ort_env, model_path, session_options); + + size_t output_count = session.GetOutputCount(); + ASSERT_TRUE(output_count == 1); + } +} + INSTANTIATE_TEST_SUITE_P(NvExecutionProviderTest, TypeTests, ::testing::Values(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, diff --git a/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc b/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc index 6a6545c68cb4f..dce0d570ec238 100644 --- a/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc +++ b/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc @@ -1,5 +1,6 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include "onnxruntime_cxx_api.h" #include "core/graph/onnx_protobuf.h" #include "core/session/inference_session.h" #include "test/providers/provider_test_utils.h" @@ -18,6 +19,8 @@ using namespace std; using namespace ONNX_NAMESPACE; using namespace ::onnxruntime::logging; +extern std::unique_ptr ort_env; + namespace onnxruntime { namespace test { @@ -1360,5 +1363,49 @@ TEST(TensorrtExecutionProviderTest, RemoveCycleTest) { ASSERT_STATUS_OK(session_object.Run(run_options, feeds, output_names, &fetches)); VerifyOutputs(fetches, expected_dims_mul_m, expected_values_mul_m); } + +TEST(TensorrtExecutionProviderTest, TestSessionOutputs) { + /* + * Model #1: + * + * "input" ---> TopK --- + * |---> "scores" + * |--- Less ---> "Less_output_0" + * |--- Div ---> "Div_output_0" + * |--- Mod ---> "labels" + */ + { + OrtTensorRTProviderOptionsV2 provider_options; + Ort::SessionOptions session_options; + session_options.AppendExecutionProvider_TensorRT_V2(provider_options); + + auto model_path = ORT_TSTR("testdata/topk_and_multiple_graph_outputs.onnx"); + Ort::Session session(*ort_env, model_path, session_options); + + size_t output_count = session.GetOutputCount(); + ASSERT_TRUE(output_count == 4); + } + + /* + * Model #2: + * + * "X" ---> Dropout ---> MatMul ---> "Y" + * ^ | + * | | + * "W" ------ ----> Can't be graph's output + * + */ + { + OrtTensorRTProviderOptionsV2 provider_options; + Ort::SessionOptions session_options; + session_options.AppendExecutionProvider_TensorRT_V2(provider_options); + + auto model_path = ORT_TSTR("testdata/node_output_not_used.onnx"); + Ort::Session session(*ort_env, model_path, session_options); + + size_t output_count = session.GetOutputCount(); + ASSERT_TRUE(output_count == 1); + } +} } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/testdata/node_output_not_used.onnx b/onnxruntime/test/testdata/node_output_not_used.onnx new file mode 100644 index 0000000000000000000000000000000000000000..e2726182fddc2c265752e46346735c26e33add4b GIT binary patch literal 189 zcmd=lo3kgAWK)CKji3J%^!XPX8xOg}ig*dpFIGBN$2_zVfB*+Ak RNCFB*q6<2)a4`t*0ss-ID|-L{ literal 0 HcmV?d00001 diff --git a/onnxruntime/test/testdata/node_output_not_used.py b/onnxruntime/test/testdata/node_output_not_used.py new file mode 100644 index 0000000000000..d36d5e9cfd2f8 --- /dev/null +++ b/onnxruntime/test/testdata/node_output_not_used.py @@ -0,0 +1,43 @@ +import onnx +from onnx import TensorProto, helper + + +def create_model_with_node_output_not_used(model_path): + # Create graph + x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [3, 2]) + w = helper.make_tensor_value_info("W", TensorProto.FLOAT, [2, 3]) + y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 3]) + + # Dropout node (two outputs) + dropout_node = helper.make_node( + "Dropout", + inputs=["X"], + outputs=["dropout_out", "dropout_mask"], + name="DropoutNode", + ) + + # MatMul node + matmul_node = helper.make_node( + "MatMul", + inputs=["dropout_out", "W"], + outputs=["Y"], + name="MatMulNode", + ) + + graph = helper.make_graph( + nodes=[dropout_node, matmul_node], + name="DropoutMatMulGraph", + inputs=[x, w], + outputs=[y], + ) + + model = helper.make_model(graph, opset_imports=[helper.make_operatorsetid("", 13)]) + + onnx.checker.check_model(model) + onnx.save(model, model_path) + + print(f"Model saved to: {model_path}") + + +if __name__ == "__main__": + create_model_with_node_output_not_used("node_output_not_used.onnx") diff --git a/onnxruntime/test/testdata/topk_and_multiple_graph_outputs.onnx b/onnxruntime/test/testdata/topk_and_multiple_graph_outputs.onnx new file mode 100644 index 0000000000000000000000000000000000000000..340c3d420d5746844be0bd3769a174b4e69de801 GIT binary patch literal 393 zcmdW?8(U zIYJ{dP(TSp09}P@53)A4oW!KmoMI_v-~1FM5Fx|~a-n-sVnK!$HwU8tyA{(KCMQO3 zEp8x_k--V Date: Tue, 2 Dec 2025 14:51:36 -0800 Subject: [PATCH 20/20] [ROCM] Remove docker, contrib ops, ci scripts related to ROCM EP (#26697) ### Description This is follow up of https://github.com/microsoft/onnxruntime/pull/25181 to remove ROCM EP related files to avoid confusion. Documents will be updated later. ### Motivation and Context https://github.com/microsoft/onnxruntime/issues/26692 --- dockerfiles/Dockerfile.rocm | 24 - dockerfiles/README.md | 17 +- dockerfiles/scripts/install_rocm_deps.sh | 84 -- .../contrib_ops/rocm/bert/attention.cu | 215 ---- onnxruntime/contrib_ops/rocm/bert/attention.h | 33 - .../contrib_ops/rocm/bert/attention_impl.cu | 435 --------- .../contrib_ops/rocm/bert/attention_impl.h | 180 ---- .../contrib_ops/rocm/bert/attention_softmax.h | 465 --------- .../bert/batched_gemm_permute_pipelines.cuh | 125 --- .../impl.cuh | 177 ---- .../impl_fp16.cu | 60 -- .../impl_fp16_biased.cu | 60 -- .../impl_fp16_biased_biased.cu | 60 -- ...ed_gemm_softmax_gemm_permute_pipelines.cuh | 915 ------------------ .../rocm/bert/decoder_attention_impl.h | 46 - .../contrib_ops/rocm/bert/elementwise.h | 84 -- .../rocm/bert/elementwise_impl/impl.cuh | 256 ----- .../bert/elementwise_impl/impl_fastgelu.cu | 9 - .../rocm/bert/elementwise_impl/impl_gelu.cu | 9 - .../rocm/bert/elementwise_impl/impl_relu.cu | 8 - .../contrib_ops/rocm/bert/gemm_fast_gelu.cc | 75 -- .../contrib_ops/rocm/bert/gemm_fast_gelu.h | 23 - .../rocm/bert/gemm_fast_gelu_ck.cuh | 133 --- .../rocm/bert/gemm_fast_gelu_common.h | 47 - .../rocm/bert/gemm_fast_gelu_impl.cu | 91 -- .../rocm/bert/gemm_fast_gelu_impl.h | 40 - .../rocm/bert/gemm_fast_gelu_tunable.cuh | 83 -- .../rocm/bert/group_query_attention.cu | 530 ---------- .../rocm/bert/group_query_attention.h | 38 - .../contrib_ops/rocm/bert/layer_norm.cuh | 270 ------ .../rocm/bert/multihead_attention.cu | 286 ------ .../rocm/bert/multihead_attention.h | 51 - .../contrib_ops/rocm/bert/skip_layer_norm.cc | 132 --- .../contrib_ops/rocm/bert/skip_layer_norm.h | 26 - .../rocm/bert/skip_layer_norm_impl.cu | 86 -- .../rocm/bert/skip_layer_norm_impl.h | 31 - .../rocm/bert/skip_layer_norm_impl_kernel.h | 162 ---- .../rocm/bert/skip_layer_norm_tunable_op.h | 161 --- .../rocm/bert/transformer_common.cc | 37 - .../rocm/bert/transformer_common.h | 46 - .../rocm/diffusion/group_norm_ck.cuh | 105 -- .../diffusion/group_norm_ck_impl/impl.cuh | 130 --- .../diffusion/group_norm_ck_impl/impl_fp16.cu | 39 - .../diffusion/group_norm_ck_impl/impl_fp32.cu | 39 - .../rocm/diffusion/group_norm_common.h | 56 -- .../rocm/diffusion/group_norm_impl.cu | 76 -- .../rocm/diffusion/group_norm_triton.cuh | 105 -- .../rocm/diffusion/group_norm_triton.py | 135 --- .../rocm/diffusion/group_norm_tunable_op.h | 220 ----- .../contrib_ops/rocm/diffusion/nhwc_conv.cc | 27 - onnxruntime/contrib_ops/rocm/fused_conv.cc | 439 --------- .../contrib_ops/rocm/math/gemm_float8.cu | 213 ---- .../contrib_ops/rocm/math/gemm_float8_ck.cuh | 276 ------ .../math/gemm_float8_ck_impl/add_instance.cu | 124 --- ...xdl_splitk_f16_f8_f16_mk_kn_mn_instance.cu | 97 -- ...k_f16_f8_f16_mk_kn_mn_instance_original.cu | 80 -- ...xdl_splitk_f16_f8_f16_mk_nk_mn_instance.cu | 94 -- ...k_f8_f16_f16_mk_kn_mn_instance_original.cu | 97 -- .../contrib_ops/rocm/rocm_contrib_kernels.cc | 347 ------- .../contrib_ops/rocm/rocm_contrib_kernels.h | 14 - .../github/linux/build_rocm_c_api_package.sh | 40 - .../docker/scripts/setup_rocm_yum_repo.sh | 43 - 62 files changed, 1 insertion(+), 8405 deletions(-) delete mode 100644 dockerfiles/Dockerfile.rocm delete mode 100644 dockerfiles/scripts/install_rocm_deps.sh delete mode 100644 onnxruntime/contrib_ops/rocm/bert/attention.cu delete mode 100644 onnxruntime/contrib_ops/rocm/bert/attention.h delete mode 100644 onnxruntime/contrib_ops/rocm/bert/attention_impl.cu delete mode 100644 onnxruntime/contrib_ops/rocm/bert/attention_impl.h delete mode 100644 onnxruntime/contrib_ops/rocm/bert/attention_softmax.h delete mode 100644 onnxruntime/contrib_ops/rocm/bert/batched_gemm_permute_pipelines.cuh delete mode 100644 onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl.cuh delete mode 100644 onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16.cu delete mode 100644 onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased.cu delete mode 100644 onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased_biased.cu delete mode 100644 onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh delete mode 100644 onnxruntime/contrib_ops/rocm/bert/decoder_attention_impl.h delete mode 100644 onnxruntime/contrib_ops/rocm/bert/elementwise.h delete mode 100644 onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl.cuh delete mode 100644 onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_fastgelu.cu delete mode 100644 onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_gelu.cu delete mode 100644 onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_relu.cu delete mode 100644 onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu.cc delete mode 100644 onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu.h delete mode 100644 onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_ck.cuh delete mode 100644 onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_common.h delete mode 100644 onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.cu delete mode 100644 onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.h delete mode 100644 onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_tunable.cuh delete mode 100644 onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu delete mode 100644 onnxruntime/contrib_ops/rocm/bert/group_query_attention.h delete mode 100644 onnxruntime/contrib_ops/rocm/bert/layer_norm.cuh delete mode 100644 onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu delete mode 100644 onnxruntime/contrib_ops/rocm/bert/multihead_attention.h delete mode 100644 onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.cc delete mode 100644 onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.h delete mode 100644 onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.cu delete mode 100644 onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.h delete mode 100644 onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl_kernel.h delete mode 100644 onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_tunable_op.h delete mode 100644 onnxruntime/contrib_ops/rocm/bert/transformer_common.cc delete mode 100644 onnxruntime/contrib_ops/rocm/bert/transformer_common.h delete mode 100644 onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh delete mode 100644 onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh delete mode 100644 onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp16.cu delete mode 100644 onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp32.cu delete mode 100644 onnxruntime/contrib_ops/rocm/diffusion/group_norm_common.h delete mode 100644 onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.cu delete mode 100644 onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh delete mode 100644 onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py delete mode 100644 onnxruntime/contrib_ops/rocm/diffusion/group_norm_tunable_op.h delete mode 100644 onnxruntime/contrib_ops/rocm/diffusion/nhwc_conv.cc delete mode 100644 onnxruntime/contrib_ops/rocm/fused_conv.cc delete mode 100644 onnxruntime/contrib_ops/rocm/math/gemm_float8.cu delete mode 100644 onnxruntime/contrib_ops/rocm/math/gemm_float8_ck.cuh delete mode 100644 onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/add_instance.cu delete mode 100644 onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance.cu delete mode 100644 onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance_original.cu delete mode 100644 onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instance.cu delete mode 100644 onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instance_original.cu delete mode 100644 onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc delete mode 100644 onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.h delete mode 100755 tools/ci_build/github/linux/build_rocm_c_api_package.sh delete mode 100755 tools/ci_build/github/linux/docker/scripts/setup_rocm_yum_repo.sh diff --git a/dockerfiles/Dockerfile.rocm b/dockerfiles/Dockerfile.rocm deleted file mode 100644 index aca8c3feaff71..0000000000000 --- a/dockerfiles/Dockerfile.rocm +++ /dev/null @@ -1,24 +0,0 @@ -# -------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------- -# Dockerfile to run ONNXRuntime with ROCm integration -#-------------------------------------------------------------------------- - -FROM rocm/pytorch:rocm6.2.3_ubuntu22.04_py3.10_pytorch_release_2.3.0 - -ARG ONNXRUNTIME_REPO=https://github.com/Microsoft/onnxruntime -ARG ONNXRUNTIME_BRANCH=main - -WORKDIR /code - -ENV PATH=/code/cmake-3.27.3-linux-x86_64/bin:${PATH} - -# Prepare onnxruntime repository & build onnxruntime -RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXRUNTIME_REPO} onnxruntime &&\ - /bin/sh onnxruntime/dockerfiles/scripts/install_common_deps.sh &&\ - cd onnxruntime &&\ - /bin/sh ./build.sh --allow_running_as_root --config Release --build_wheel --update --build --parallel --cmake_extra_defines\ - ONNXRUNTIME_VERSION=$(cat ./VERSION_NUMBER) --use_rocm --rocm_home=/opt/rocm &&\ - pip install /code/onnxruntime/build/Linux/Release/dist/*.whl &&\ - cd .. diff --git a/dockerfiles/README.md b/dockerfiles/README.md index 4c69098103edd..88c542b63ccd2 100644 --- a/dockerfiles/README.md +++ b/dockerfiles/README.md @@ -1,9 +1,8 @@ # Dockerfiles **Execution Providers** - CPU: [Dockerfile](Dockerfile.source), [Instructions](#cpu) -- CUDA/cuDNN: [Dockerfile](Dockerfile.cuda), [Instructions](#cuda) +- CUDA: [Dockerfile](Dockerfile.cuda), [Instructions](#cuda) - MIGraphX: [Dockerfile](Dockerfile.migraphx), [Instructions](#migraphx) -- ROCm: [Dockerfile](Dockerfile.rocm), [Instructions](#rocm) - OpenVINO: [Dockerfile](Dockerfile.openvino), [Instructions](#openvino) - TensorRT: [Dockerfile](Dockerfile.tensorrt), [Instructions](#tensorrt) - VitisAI: [Dockerfile](Dockerfile.vitisai) @@ -304,17 +303,3 @@ Note: When running the container you built in Docker, please either use 'nvidia- ``` docker run -it --device=/dev/kfd --device=/dev/dri --group-add video onnxruntime-migraphx ``` - - ## ROCm -**Ubuntu 22.04, ROCm6.2.3** - -1. Build the docker image from the Dockerfile in this repository. - ``` - docker build -t onnxruntime-rocm -f Dockerfile.rocm . - ``` - -2. Run the Docker image - - ``` - docker run -it --device=/dev/kfd --device=/dev/dri --group-add video onnxruntime-rocm - ``` diff --git a/dockerfiles/scripts/install_rocm_deps.sh b/dockerfiles/scripts/install_rocm_deps.sh deleted file mode 100644 index fd445be87479b..0000000000000 --- a/dockerfiles/scripts/install_rocm_deps.sh +++ /dev/null @@ -1,84 +0,0 @@ -#!/bin/bash -prefix=/opt/rocm -DEBIAN_FRONTEND=noninteractive -apt-get update && apt-get install -y --no-install-recommends \ - wget \ - zip \ - ca-certificates \ - build-essential \ - curl \ - libcurl4-openssl-dev \ - libssl-dev \ - python3-dev - -# rocm-cmake -rocm_cmake_version=4.5.2 -wget --quiet https://github.com/RadeonOpenCompute/rocm-cmake/archive/refs/tags/rocm-${rocm_cmake_version}.tar.gz -tar -xzvf rocm-${rocm_cmake_version}.tar.gz -rm rocm-${rocm_cmake_version}.tar.gz -cd rocm-cmake-rocm-${rocm_cmake_version} -mkdir build -cd build -cmake -DCMAKE_INSTALL_PREFIX=$prefix .. -make -j8 -make install -cd ../.. -rm -rf rocm-cmake-rocm-${rocm_cmake_version} - -# rccl -rccl_version=4.5.2 -wget --quiet https://github.com/ROCmSoftwarePlatform/rccl/archive/refs/tags/rocm-${rccl_version}.tar.gz -tar -xzvf rocm-${rccl_version}.tar.gz -rm rocm-${rccl_version}.tar.gz -cd rccl-rocm-${rccl_version} -mkdir build -cd build -CXX=/opt/rocm/bin/hipcc cmake -DCMAKE_INSTALL_PREFIX=$prefix .. -make -j8 -make install -cd ../.. -rm -rf rccl-rocm-${rccl_version} - -#rocrand -rocrand_version=4.5.2 -wget --quiet https://github.com/ROCmSoftwarePlatform/rocRAND/archive/refs/tags/rocm-${rocrand_version}.tar.gz -tar -xzvf rocm-${rocrand_version}.tar.gz -rm rocm-${rocrand_version}.tar.gz -cd rocRAND-rocm-${rocrand_version} -mkdir build -cd build -CXX=/opt/rocm/bin/hipcc cmake -DCMAKE_INSTALL_PREFIX=$prefix .. -make -j8 -make install -cd ../.. -rm -rf rocRAND-rocm-${rocrand_version} - -#hipcub -hipcub_version=4.5.2 -wget --quiet https://github.com/ROCmSoftwarePlatform/hipCUB/archive/refs/tags/rocm-${hipcub_version}.tar.gz -tar -xzvf rocm-${hipcub_version}.tar.gz -rm rocm-${hipcub_version}.tar.gz -cd hipCUB-rocm-${hipcub_version} -mkdir build -cd build -CXX=/opt/rocm/bin/hipcc cmake -DCMAKE_INSTALL_PREFIX=$prefix .. -make -j8 -make package -make install -cd ../.. -rm -rf hipCUB-rocm-${hipcub_version} - -#rocprim -rocprim_version=4.5.2 -wget --quiet https://github.com/ROCmSoftwarePlatform/rocPRIM/archive/refs/tags/rocm-${rocprim_version}.tar.gz -tar -xzvf rocm-${rocprim_version}.tar.gz -rm rocm-${rocprim_version}.tar.gz -cd rocPRIM-rocm-${rocprim_version} -mkdir build -cd build -CXX=/opt/rocm/bin/hipcc cmake -DCMAKE_INSTALL_PREFIX=$prefix .. -make -j8 -make install -cd ../.. -rm -rf rocPRIM-rocm-${rocprim_version} - diff --git a/onnxruntime/contrib_ops/rocm/bert/attention.cu b/onnxruntime/contrib_ops/rocm/bert/attention.cu deleted file mode 100644 index b40fc2bf0eef8..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/attention.cu +++ /dev/null @@ -1,215 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "contrib_ops/rocm/bert/attention.h" -#include "contrib_ops/rocm/bert/attention_impl.h" -#include "contrib_ops/rocm/bert/batched_gemm_permute_pipelines.cuh" -#include "contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh" -#include "contrib_ops/rocm/bert/transformer_common.h" -#include "core/providers/rocm/rocm_common.h" -#include "core/providers/rocm/shared_inc/fpgeneric.h" -#include "core/providers/rocm/tunable/gemm.h" - -using namespace onnxruntime::rocm; -using namespace ::onnxruntime::common; -using namespace ONNX_NAMESPACE; - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -constexpr int kPastSequenceLengthInputIndex = 6; -constexpr int kPastInputIndex = 4; -constexpr int kPresentOutputIndex = 1; - -#define REGISTER_KERNEL_TYPED(T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - Attention, \ - kMSDomain, \ - 1, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .MayInplace(kPastInputIndex, kPresentOutputIndex) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ - .InputMemoryType(OrtMemTypeCPUInput, kPastSequenceLengthInputIndex), \ - Attention); - -REGISTER_KERNEL_TYPED(float) -REGISTER_KERNEL_TYPED(MLFloat16) - -template -Attention::Attention(const OpKernelInfo& info) - : RocmKernel(info), AttentionBase(info, true), attn_type_(kAttention) { - using HipT = typename ToHipType::MappedType; - using AttentionTunableOp = GemmSoftmaxGemmPermuteTunableOp; - tunable_op_ = std::make_shared(); -} - -template -Status Attention::ComputeInternal(OpKernelContext* context) const { - const Tensor* input = context->Input(0); - const Tensor* weights = context->Input(1); - const Tensor* bias = context->Input(2); - const Tensor* mask_index = context->Input(3); - const Tensor* past = context->Input(4); - const Tensor* attention_bias = context->Input(5); - const Tensor* past_seq_len = context->Input(kPastSequenceLengthInputIndex); - - auto& device_prop = GetDeviceProp(); - RocmAttentionParameters attn; - ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(), - weights->Shape(), - bias->Shape(), - mask_index, - past, - attention_bias, - &attn, - device_prop.maxThreadsPerBlock, - past_seq_len)); - ORT_ENFORCE(attn.sequence_length == attn.kv_sequence_length); // self attention - ORT_ENFORCE(attn.qkv_format == Q_K_V_BNSH); // non-packed, permuted - - TensorShapeVector output_shape(3); - output_shape[0] = static_cast(attn.batch_size); - output_shape[1] = static_cast(attn.sequence_length); - output_shape[2] = static_cast(attn.v_hidden_size); - Tensor* output = context->Output(0, output_shape); - - std::vector present_dims{ - 2, attn.batch_size, attn.num_heads, - past_present_share_buffer_ ? attn.max_sequence_length : attn.total_sequence_length, - attn.head_size}; - TensorShape present_shape(present_dims); - Tensor* present = context->Output(kPresentOutputIndex, present_shape); - - auto stream = Stream(context); - hipblasHandle_t hipblas = GetHipblasHandle(context); - - using HipT = typename ToHipType::MappedType; - using QkvProjectGeneric = GemmPermuteGenericPipeline; - using AttentionGeneric = GemmSoftmaxGemmPermuteGenericPipeline; - using AttentionTunableOp = GemmSoftmaxGemmPermuteTunableOp; - - ORT_RETURN_IF_ERROR(ClassifyAttentionMode(attn_type_, &attn, /*qkv=*/{}, /*past=*/{past}, /*present=*/{present})); - ORT_ENFORCE(attn.mode == QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE || - attn.mode == QFMT_KFMT_VFMT_NONE_NONE_2BNTH_NONE || - attn.mode == QFMT_KFMT_VFMT_NONE_NONE_2BNMH_NONE || - attn.mode == QFMT_KFMT_VFMT_2BNPH_NONE_2BNTH_NONE || - attn.mode == QFMT_KFMT_VFMT_2BNMH_NONE_2BNMH_NONE); - - size_t qkv_project_output_bytes = QkvProjectGeneric::GetOutputNumBytes(&attn); - size_t shared_workspace_bytes = std::max(QkvProjectGeneric::GetWorkspaceNumBytes(&attn), - AttentionGeneric::GetWorkspaceNumBytes(&attn)); - if (GetTuningContext()->IsTunableOpEnabled()) { - shared_workspace_bytes = std::max(shared_workspace_bytes, AttentionTunableOp::GetWorkspaceNumBytes(&attn)); - } - - auto qkv_project_output = GetScratchBuffer(qkv_project_output_bytes, context->GetComputeStream()); - auto workspace = GetScratchBuffer(shared_workspace_bytes, context->GetComputeStream()); - - GemmPermuteParams gemm_permute_params; - { - auto& params = gemm_permute_params; - params.tuning_ctx = GetTuningContext(); - params.stream = context->GetComputeStream(); - params.handle = hipblas; - params.attention = &attn; - params.device_prop = &device_prop; - - params.input_buffer = reinterpret_cast(input->DataRaw()); - params.weight_buffer = reinterpret_cast(weights->DataRaw()); - params.bias_buffer = reinterpret_cast(bias->DataRaw()); - params.out_buffer = reinterpret_cast(qkv_project_output.get()); - params.ones = GetConstOnes(attn.batch_size * attn.sequence_length, stream); - params.workspace_buffer = reinterpret_cast(workspace.get()); - } - - ORT_RETURN_IF_ERROR(QkvProjectGeneric::Run(&gemm_permute_params)); - auto [q_buffer, k_buffer, v_buffer] = QkvProjectGeneric::UnspliceOutputQKV(&gemm_permute_params); - - // NOTE: GemmPermute always output 3BNSH, k_buffer and v_buffer can be treated as 2BNSH - if (nullptr != present) { - Strides dst_strides; // the output buffer is present Tensor, the buffer is the same - - int4 add_shape{2 * attn.batch_size, attn.num_heads, attn.sequence_length, attn.head_size}; - HipT* add_dest = nullptr; // destination of concatenated data to present - const HipT* const add_src = k_buffer; // source of concatenated data to present - const auto add_src_strides = Strides::BNSHMemory( - 2 * attn.batch_size, attn.num_heads, attn.sequence_length, attn.head_size); - - if (attn.mode == QFMT_KFMT_VFMT_NONE_NONE_2BNTH_NONE) { - dst_strides = Strides::BNSHMemory(2 * attn.batch_size, attn.num_heads, attn.total_sequence_length, attn.head_size); - add_dest = reinterpret_cast(present->MutableDataRaw()) /* + dst_strides.OffsetAt(0, 0, 0, 0)*/; - } else if (attn.mode == QFMT_KFMT_VFMT_2BNPH_NONE_2BNTH_NONE) { - dst_strides = Strides::BNSHMemory(2 * attn.batch_size, attn.num_heads, attn.total_sequence_length, attn.head_size); - add_dest = reinterpret_cast(present->MutableDataRaw()) + dst_strides.OffsetAt(0, 0, attn.past_sequence_length, 0); - - // We only need to copy past to present in this case. All other cases will be build the present incrementally - const int4 past_shape = {2 * attn.batch_size, attn.num_heads, attn.past_sequence_length, attn.head_size}; - HipT* const past_dest = reinterpret_cast(present->MutableDataRaw()); - const HipT* const past_src = reinterpret_cast(past->DataRaw()); - const Strides past_src_strides = Strides::BNSHMemory( - 2 * attn.batch_size, attn.num_heads, attn.past_sequence_length, attn.head_size); - - ORT_RETURN_IF_ERROR(LaunchStridedCopy(stream, past_src, past_shape, past_src_strides.ForBNSHCoord(), - past_dest, dst_strides.ForBNSHCoord(), device_prop.maxThreadsPerBlock)); - } else if (attn.mode == QFMT_KFMT_VFMT_NONE_NONE_2BNMH_NONE) { - dst_strides = Strides::BNSHMemory(2 * attn.batch_size, attn.num_heads, attn.max_sequence_length, attn.head_size); - add_dest = reinterpret_cast(present->MutableDataRaw()) /* + dst_strides.OffsetAt(0, 0, 0, 0)*/; - } else if (attn.mode == QFMT_KFMT_VFMT_2BNMH_NONE_2BNMH_NONE) { - dst_strides = Strides::BNSHMemory(2 * attn.batch_size, attn.num_heads, attn.max_sequence_length, attn.head_size); - add_dest = reinterpret_cast(present->MutableDataRaw()) + dst_strides.OffsetAt(0, 0, attn.past_sequence_length, 0); - } - - ORT_RETURN_IF_ERROR(LaunchStridedCopy(stream, add_src, add_shape, add_src_strides.ForBNSHCoord(), - add_dest, dst_strides.ForBNSHCoord(), device_prop.maxThreadsPerBlock)); - - // update pointers to present_k and present_v. TODO: switch to ConvertToOffsetedBufferViews - k_buffer = reinterpret_cast(present->MutableDataRaw()); - v_buffer = reinterpret_cast(present->MutableDataRaw()) + dst_strides.OffsetAt(attn.batch_size, 0, 0, 0); - } - - // For testing, environment variable ORT_TRANSFORMER_OPTIONS=1 could enable persistent softmax - const TransformerOptions* options = TransformerOptions::GetInstance(); - bool use_persistent_softmax = options->IsPrecisionMode() && !options->DisablePersistentSoftmax(); - - GemmSoftmaxGemmPermuteParams gemm_softmax_gemm_permute_params; - { - auto& params = gemm_softmax_gemm_permute_params; - params.tuning_ctx = GetTuningContext(); - params.stream = context->GetComputeStream(); - params.handle = hipblas; - params.attention = &attn; - params.device_prop = &device_prop; - // FIXME: the params.scale seems to be different from AttentionParameters::scale; - params.scale = 1.0f / sqrt(static_cast(attn.head_size)); - // TODO: switch to ConvertToOffsetedBufferViews - params.q_buffer = q_buffer; - params.k_buffer = k_buffer; - params.v_buffer = v_buffer; - params.out_buffer = reinterpret_cast(output->MutableDataRaw()); - - if (attention_bias != nullptr) { - params.bias_buffer = reinterpret_cast(attention_bias->DataRaw()); - } - - if (mask_index != nullptr) { - params.mask_index_buffer = mask_index->Data(); - params.mask_index_dims = mask_index->Shape().AsShapeVector(); - } - - params.workspace_buffer = reinterpret_cast(workspace.get()); - } - - if (this->GetTuningContext()->IsTunableOpEnabled() && - !use_persistent_softmax) { - return (*std::static_pointer_cast(tunable_op_))(&gemm_softmax_gemm_permute_params); - } else { - return AttentionGeneric::Run(&gemm_softmax_gemm_permute_params, use_persistent_softmax); - } -} - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/attention.h b/onnxruntime/contrib_ops/rocm/bert/attention.h deleted file mode 100644 index 7204fd660a516..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/attention.h +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "core/providers/rocm/rocm_kernel.h" -#include "contrib_ops/cpu/bert/attention_base.h" -#include "contrib_ops/rocm/bert/attention_impl.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -using namespace onnxruntime::rocm; - -template -class Attention final : public RocmKernel, public AttentionBase { - public: - Attention(const OpKernelInfo& info); - Status ComputeInternal(OpKernelContext* context) const override; - - public: - AttentionType attn_type_; - - // type-erased GemmSoftmaxGemmPermuteTunableOp, the reason for this is: - // 1. We don't want to include the cuh file where GemmSoftmaxGemmPermuteTunableOp is defined. - // 2. We don't want to construct the object repeatly (which is expansive) during Compute. - std::shared_ptr tunable_op_; -}; - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu b/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu deleted file mode 100644 index 270a8e51daf88..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu +++ /dev/null @@ -1,435 +0,0 @@ -/* - The implementation of this file is based on qkvToContext plugin in TensorRT demo: - https://github.com/NVIDIA/TensorRT/tree/release/5.1/demo/BERT/ - -Copyright 2019 NVIDIA Corporation - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -// Modifications: scaling is moved from masked softmax to the gemm before that. -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include -#include "core/providers/rocm/cu_inc/common.cuh" -#include "core/providers/rocm/rocm_common.h" -#include "core/providers/rocm/shared_inc/fpgeneric.h" -#include "core/providers/rocm/tunable/gemm.h" -#include "core/providers/rocm/tunable/rocm_tunable.h" -#include "contrib_ops/cpu/bert/attention_base.h" -#include "contrib_ops/rocm/bert/attention_impl.h" -#include "contrib_ops/rocm/bert/attention_softmax.h" -#include "contrib_ops/rocm/bert/decoder_attention_impl.h" - -using namespace onnxruntime::rocm; - -namespace blas = onnxruntime::rocm::tunable::blas; - -#define CHECK_ROCM(expr) HIP_RETURN_IF_ERROR(expr) - -using namespace onnxruntime::rocm; -using namespace ::onnxruntime::common; -using namespace ONNX_NAMESPACE; - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -static size_t AlignTo(size_t a, size_t b) { - return CeilDiv(a, b) * b; -} - -size_t GetAttentionScratchSize(size_t element_size, - int batch_size, - int num_heads, - int sequence_length, - int total_sequence_length) { - const size_t bytes = element_size * batch_size * num_heads * sequence_length * total_sequence_length; - - const size_t alignment = 256; - const size_t bytesAligned = AlignTo(bytes, alignment); - return bytesAligned; -} - -size_t GetAttentionWorkspaceSize( - size_t element_size, - int batch_size, - int num_heads, - int head_size, - int sequence_length, - int total_sequence_length) { - size_t qkv_size = element_size * 3 * batch_size * sequence_length * num_heads * head_size; - return qkv_size + 2 * GetAttentionScratchSize(element_size, batch_size, num_heads, - sequence_length, total_sequence_length); -} - -inline int3 Get2DMaskStrides(int total_sequence_length) { - // stride == 0 indicate broadcasting - return {total_sequence_length, 0, 1}; -} - -Status ClassifyAttentionMode( - AttentionType attn_type, - RocmAttentionParameters* attn, - const std::vector& qkv, - const std::vector& past, - const std::vector& present) { - size_t num_qkv = std::count_if(qkv.cbegin(), qkv.cend(), [](auto it) { return it != nullptr; }); - size_t num_past = std::count_if(past.cbegin(), past.cend(), [](auto it) { return it != nullptr; }); - size_t num_present = std::count_if(present.cbegin(), present.cend(), [](auto it) { return it != nullptr; }); - - auto hint = MakeString(num_qkv, " qkv inputs, ", num_past, " past inputs and ", num_present, " present inputs"); - LOGS_DEFAULT(VERBOSE) << hint; - - if (attn_type == kAttention) { - ORT_ENFORCE(num_qkv == 0); - if (num_past == 0 && num_present == 0) { - attn->mode = QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE; - return Status::OK(); - } else if (num_past == 0 && num_present == 1) { - if (attn->past_present_share_buffer == false) { - attn->mode = QFMT_KFMT_VFMT_NONE_NONE_2BNTH_NONE; - return Status::OK(); - } else { - attn->mode = QFMT_KFMT_VFMT_NONE_NONE_2BNMH_NONE; - return Status::OK(); - } - } else if (num_past == 1 && num_present == 1) { - if (attn->past_present_share_buffer == false) { - attn->mode = QFMT_KFMT_VFMT_2BNPH_NONE_2BNTH_NONE; - return Status::OK(); - } else { - attn->mode = QFMT_KFMT_VFMT_2BNMH_NONE_2BNMH_NONE; - return Status::OK(); - } - } - } else if (attn_type == kMultiHeadAttention || attn_type == kDecoderMaskedMultiHeadAttention) { - if (num_qkv == 3 && num_past == 0 && num_present == 0) { - if (attn->qkv_format == Q_K_V_BSNH) { - attn->mode = BSNH_BLNH_BLNH_NONE_NONE_NONE_NONE; - return Status::OK(); - } else if (attn->qkv_format == Q_K_V_BSNH_BNSH_BNSH) { - attn->mode = BSNH_BNLH_BNLH_NONE_NONE_NONE_NONE; - return Status::OK(); - } - } else if (num_qkv == 3 && num_past == 0 && num_present == 2) { - if (attn->past_present_share_buffer == false) { - if (attn->qkv_format == Q_K_V_BSNH) { - attn->mode = BSNH_BLNH_BLNH_NONE_NONE_BNTH_BNTH; - return Status::OK(); - } else if (attn->qkv_format == Q_K_V_BSNH_BNSH_BNSH) { - attn->mode = BSNH_BNLH_BNLH_NONE_NONE_BNTH_BNTH; - return Status::OK(); - } - } else { - if (attn->qkv_format == Q_K_V_BSNH) { - attn->mode = BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH; - return Status::OK(); - } else if (attn->qkv_format == Q_K_V_BSNH_BNSH_BNSH) { - attn->mode = BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH; - return Status::OK(); - } - } - } else if (num_qkv == 3 && num_past == 2 && num_present == 2) { - if (attn->past_present_share_buffer == false) { - if (attn->qkv_format == Q_K_V_BSNH) { - attn->mode = BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH; - return Status::OK(); - } else if (attn->qkv_format == Q_K_V_BSNH_BNSH_BNSH) { - attn->mode = BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH; - return Status::OK(); - } - } else { - if (attn->qkv_format == Q_K_V_BSNH) { - attn->mode = BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH; - return Status::OK(); - } else if (attn->qkv_format == Q_K_V_BSNH_BNSH_BNSH) { - attn->mode = BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH; - return Status::OK(); - } - } - } else if (num_qkv == 1 && num_past == 0 && num_present == 0) { - if (attn->qkv_format == QKV_BSN3H) { - attn->mode = BLN3H_NONE_NONE_NONE_NONE_NONE_NONE; - return Status::OK(); - } - } else if (num_qkv == 2 && num_past == 0 && num_present == 0) { - if (attn->qkv_format == Q_KV_BSNH_BSN2H) { - attn->mode = BSNH_BLN2H_NONE_NONE_NONE_NONE_NONE; - return Status::OK(); - } - } - } - return ORT_MAKE_STATUS( - ONNXRUNTIME, INVALID_ARGUMENT, - "Unsupported AttentionMode for ", attn_type, ". Got qkv format ", attn->qkv_format, - ". Got ", hint); -} - -template -Status DecoderQkvToContext( - const hipDeviceProp_t& prop, - RocmTuningContext* tuning_ctx, - Stream* ort_stream, - hipblasHandle_t& hipblas, - const size_t element_size, - const int batch_size, - const int sequence_length, - const int kv_sequence_length, - const int num_heads, - const int head_size, - const bool static_kv, - const bool use_past, - const bool has_layer_state, - const bool has_key_padding_mask, - const float mask_filter_value, - const T* gemm_query_buffer, - const T* gemm_kv_buffer, - const bool* key_padding_mask, - const T* key_cache, - const T* value_cache, - T* qkv_buffer, - T* workspace_buffer, - T* output, - T* new_key_cache, - T* new_value_cache) { - const int max_threads_per_block = prop.maxThreadsPerBlock; - const int BN = batch_size * num_heads; - const int BHN = BN * head_size; - const int BNS = BN * sequence_length; - const int k_buffer_offset = sequence_length * BHN; - const int v_buffer_offset = (sequence_length + kv_sequence_length) * BHN; - - T* temp_qkv_buffer = workspace_buffer; - auto stream = static_cast(ort_stream->GetHandle()); - - const T* q = qkv_buffer; - // transpose q and copy them to qkv_buffer - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, head_size, - num_heads, max_threads_per_block, true, gemm_query_buffer, qkv_buffer)); - - const T* k = qkv_buffer + k_buffer_offset; - const T* v = qkv_buffer + v_buffer_offset; - if (!has_layer_state || !use_past) { - if (!static_kv) { - // transpose kv and copy them to qkv_buffer - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 2, sequence_length, batch_size, head_size, num_heads, - max_threads_per_block, true, gemm_kv_buffer, qkv_buffer + k_buffer_offset)); - } else { - // transpose kv and copy them to qkv_buffer - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 2, kv_sequence_length, batch_size, head_size, num_heads, - max_threads_per_block, true, gemm_kv_buffer, qkv_buffer + k_buffer_offset)); - } - } else { - if (!static_kv) { - // transpose kv and copy them to temp_buffer - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 2, sequence_length, batch_size, head_size, num_heads, - max_threads_per_block, true, gemm_kv_buffer, temp_qkv_buffer)); - // concat cache-k with k and copy to qkv_buffer - if (nullptr != key_cache) { - ORT_RETURN_IF_ERROR(LaunchConcatTensorToTensor(stream, kv_sequence_length, sequence_length, - batch_size, head_size, num_heads, - max_threads_per_block, 1, key_cache, - temp_qkv_buffer, qkv_buffer + k_buffer_offset)); - } - // concat cache-v with v and copy to qkv_buffer - if (nullptr != value_cache) { - ORT_RETURN_IF_ERROR(LaunchConcatTensorToTensor(stream, kv_sequence_length, sequence_length, - batch_size, head_size, num_heads, - max_threads_per_block, 1, value_cache, - temp_qkv_buffer + k_buffer_offset, - qkv_buffer + v_buffer_offset)); - } - } - } - - if (has_layer_state) { - if (use_past && static_kv) { - CHECK_ROCM(hipMemcpyAsync(new_key_cache, key_cache, - kv_sequence_length * BHN * sizeof(T), hipMemcpyDeviceToDevice, stream)); - CHECK_ROCM(hipMemcpyAsync(new_value_cache, value_cache, - kv_sequence_length * BHN * sizeof(T), hipMemcpyDeviceToDevice, stream)); - } else { - CHECK_ROCM(hipMemcpyAsync(new_key_cache, k, - kv_sequence_length * BHN * sizeof(T), hipMemcpyDeviceToDevice, stream)); - CHECK_ROCM(hipMemcpyAsync(new_value_cache, v, - kv_sequence_length * BHN * sizeof(T), hipMemcpyDeviceToDevice, stream)); - } - } - - // scratch1: BxNxSxS* buffer - // scratch2: BxNxSxS* buffer - // scratch3: BxNxSxH buffer - T* scratch1 = temp_qkv_buffer + 3 * BHN * sequence_length; - T* scratch2 = scratch1 + BNS * kv_sequence_length; - T* scratch3 = scratch2 + BNS * kv_sequence_length; - - // compute Q*K' (as K'*Q), scaled by 1/sqrt(H) and store in scratch1: BxNxSxS* - // Q: BxNxSxH, K (present_k): BxNxS*xH, Q*K': BxNxSxS* - const float rsqrt_head_size = 1.f / sqrt(static_cast(head_size)); - const int temp_matrix_size = sequence_length * kv_sequence_length; - - const int strideA = kv_sequence_length * head_size; - const int strideB = sequence_length * head_size; - if (use_past && static_kv) { - ORT_RETURN_IF_ERROR(blas::column_major::StridedBatchedGemm( - tuning_ctx, ort_stream, hipblas, - blas::BlasOp::Trans, blas::BlasOp::NonTrans, - kv_sequence_length, sequence_length, head_size, - /*alpha=*/rsqrt_head_size, - key_cache, head_size, strideA, - q, head_size, strideB, - /*beta=*/0.0f, - scratch1, kv_sequence_length, temp_matrix_size, - BN)); - } else { - ORT_RETURN_IF_ERROR(blas::column_major::StridedBatchedGemm( - tuning_ctx, ort_stream, hipblas, - blas::BlasOp::Trans, blas::BlasOp::NonTrans, - kv_sequence_length, sequence_length, head_size, - /*alpha=*/rsqrt_head_size, - k, head_size, strideA, - q, head_size, strideB, - /*beta=*/0.0f, - scratch1, kv_sequence_length, temp_matrix_size, - BN)); - } - - if (has_key_padding_mask) { - int3 strides = Get2DMaskStrides(kv_sequence_length); - ORT_RETURN_IF_ERROR(ComputeSoftmaxWithRawMask( - ort_stream, kv_sequence_length, sequence_length, batch_size, num_heads, - strides, nullptr, key_padding_mask, nullptr, scratch1, scratch2, - false, 1.0f, false, nullptr, mask_filter_value)); - } else { - ORT_RETURN_IF_ERROR(ComputeSoftmax(stream, kv_sequence_length, sequence_length, batch_size, - num_heads, nullptr, scratch1, scratch2, false)); - } - - // compute P*V (as V*P), and store in scratch3: BxNxSxH - if (use_past && static_kv) { - ORT_RETURN_IF_ERROR(blas::column_major::StridedBatchedGemm( - tuning_ctx, ort_stream, hipblas, - blas::BlasOp::NonTrans, blas::BlasOp::NonTrans, - head_size, sequence_length, kv_sequence_length, - /*alpha=*/1.0f, - value_cache, head_size, strideA, - scratch2, kv_sequence_length, temp_matrix_size, - /*beta=*/0.0f, - scratch3, head_size, strideB, - BN)); - } else { - ORT_RETURN_IF_ERROR(blas::column_major::StridedBatchedGemm( - tuning_ctx, ort_stream, hipblas, - blas::BlasOp::NonTrans, blas::BlasOp::NonTrans, - head_size, sequence_length, kv_sequence_length, - /*alpha=*/1.0f, - v, head_size, strideA, - scratch2, kv_sequence_length, temp_matrix_size, - /*beta=*/0.0f, - scratch3, head_size, strideB, - BN)); - } - - // scratch3 is BxNxSxH, transpose to output SxBxNxH - return LaunchTransCtx(stream, sequence_length, batch_size, head_size, - num_heads, max_threads_per_block, true, scratch3, output); -} - -Status LaunchDecoderAttentionKernel( - const hipDeviceProp_t& prop, - RocmTuningContext* tuning_ctx, - Stream* stream, - hipblasHandle_t& hipblas, - const size_t element_size, - const int batch_size, - const int sequence_length, - const int kv_sequence_length, - const int num_heads, - const int head_size, - const bool static_kv, - const bool use_past, - const bool has_layer_state, - const bool has_key_padding_mask, - const float mask_filter_value, - const void* gemm_query_buffer, - const void* gemm_kv_buffer, - const bool* key_padding_mask, - const void* key_cache, - const void* value_cache, - void* qkv_buffer, - void* workspace_buffer, - void* output, - void* new_key_cache, - void* new_value_cache) { - if (element_size == 2) { - return DecoderQkvToContext( - prop, - tuning_ctx, - stream, - hipblas, - element_size, - batch_size, - sequence_length, - kv_sequence_length, - num_heads, - head_size, - static_kv, - use_past, - has_layer_state, - has_key_padding_mask, - mask_filter_value, - reinterpret_cast(gemm_query_buffer), - reinterpret_cast(gemm_kv_buffer), - key_padding_mask, - reinterpret_cast(key_cache), - reinterpret_cast(value_cache), - reinterpret_cast(qkv_buffer), - reinterpret_cast(workspace_buffer), - reinterpret_cast(output), - reinterpret_cast(new_key_cache), - reinterpret_cast(new_value_cache)); - } else { - return DecoderQkvToContext( - prop, - tuning_ctx, - stream, - hipblas, - element_size, - batch_size, - sequence_length, - kv_sequence_length, - num_heads, - head_size, - static_kv, - use_past, - has_layer_state, - has_key_padding_mask, - mask_filter_value, - reinterpret_cast(gemm_query_buffer), - reinterpret_cast(gemm_kv_buffer), - key_padding_mask, - reinterpret_cast(key_cache), - reinterpret_cast(value_cache), - reinterpret_cast(qkv_buffer), - reinterpret_cast(workspace_buffer), - reinterpret_cast(output), - reinterpret_cast(new_key_cache), - reinterpret_cast(new_value_cache)); - } -} - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/attention_impl.h b/onnxruntime/contrib_ops/rocm/bert/attention_impl.h deleted file mode 100644 index 07d875e90fa4b..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/attention_impl.h +++ /dev/null @@ -1,180 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include -#include "contrib_ops/cpu/bert/attention_common.h" -#include "contrib_ops/cpu/bert/attention_parameters.h" -#include "core/providers/rocm/shared_inc/rocm_utils.h" -#include "core/providers/rocm/tunable/rocm_tunable.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -typedef struct __align__(32) { - long long int x, y, z, w; -} LongLong4; - -size_t GetAttentionScratchSize( - size_t element_size, - int batch_size, - int num_heads, - int sequence_length, - int all_sequence_length); - -size_t GetAttentionWorkspaceSize( - size_t element_size, - int batch_size, - int num_heads, - int head_size, - int sequence_length, - int past_sequence_length); - -Status LaunchTransCtx(hipStream_t stream, - const int sequence_length, const int batch_size, const int head_size, const int num_heads, - const int max_threads_per_block, const bool reversed_bs, const float* input, float* output); - -Status LaunchTransCtx(hipStream_t stream, - const int sequence_length, const int batch_size, const int head_size, const int num_heads, - const int max_threads_per_block, const bool reversed_bs, const half* input, half* output); - -Status LaunchTransQkv(hipStream_t stream, const int matrix_num, - const int sequence_length, const int batch_size, const int head_size, const int num_heads, - const int max_threads_per_block, const bool reversed_bs, const float* input, float* output, - int total_matrix_count = -1); - -Status LaunchTransQkv(hipStream_t stream, const int matrix_num, - const int sequence_length, const int batch_size, const int head_size, const int num_heads, - const int max_threads_per_block, const bool reversed_bs, const half* input, half* output, - int total_matrix_count = -1); - -Status LaunchConcatTensorToTensor(hipStream_t stream, - const int all_sequence_length, - const int sequence_length, - const int batch_size, - const int head_size, - const int num_heads, - const int max_threads_per_block, - const int matrix_num, - const float* tensor_in, - const float* tensor_add, - float* tensor_out); - -Status LaunchConcatTensorToTensor(hipStream_t stream, - const int all_sequence_length, - const int sequence_length, - const int batch_size, - const int head_size, - const int num_heads, - const int max_threads_per_block, - const int matrix_num, - const half* tensor_in, - const half* tensor_add, - half* tensor_out); - -inline hipblasStatus_t _compat_hipblas_gemm_strided_batched_ex(hipblasHandle_t handle, - hipblasOperation_t transa, - hipblasOperation_t transb, - int m, - int n, - int k, - const void* alpha, - const void* A, - hipDataType a_type, - int lda, - hipblasStride stride_A, - const void* b, - hipDataType b_type, - int ldb, - hipblasStride stride_b, - const void* beta, - void* c, - hipDataType c_type, - int ldc, - hipblasStride stride_c, - int batch_count, - hipblasComputeType_t compute_type, - hipblasGemmAlgo_t algo) { - return hipblasGemmStridedBatchedEx(handle, - transa, - transb, - m, // m - n, // n - k, // k - alpha, // alpha - A, // A - a_type, // A type - lda, // lda - stride_A, // strideA - b, // B - b_type, // B type - ldb, // ldb - stride_b, // strideB - beta, // beta - c, // C - c_type, // C type - ldc, // ldc - stride_c, // strideC - batch_count, // batch count - compute_type, - algo); -} - -// Compatible for CublasMathModeSetter -class CompatHipblasMathModeSetter { - public: - CompatHipblasMathModeSetter(const hipDeviceProp_t&, - hipblasHandle_t, - int) { - } -}; - -enum AttentionMode { - // Q,K,V,PastK,PastV,PresentK,PresentV - QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE, - QFMT_KFMT_VFMT_NONE_NONE_2BNTH_NONE, - QFMT_KFMT_VFMT_NONE_NONE_2BNMH_NONE, - QFMT_KFMT_VFMT_2BNPH_NONE_2BNTH_NONE, - QFMT_KFMT_VFMT_2BNMH_NONE_2BNMH_NONE, - BSNH_BLNH_BLNH_NONE_NONE_NONE_NONE, - BSNH_BNLH_BNLH_NONE_NONE_NONE_NONE, - BSNH_BLNH_BLNH_NONE_NONE_BNTH_BNTH, - BSNH_BNLH_BNLH_NONE_NONE_BNTH_BNTH, - BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH, - BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH, - BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH, - BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH, - BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH, - BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH, - BLN3H_NONE_NONE_NONE_NONE_NONE_NONE, - BSNH_BLN2H_NONE_NONE_NONE_NONE_NONE, -}; - -struct RocmAttentionParameters : AttentionParameters { - AttentionMode mode; -}; - -Status ClassifyAttentionMode(AttentionType type, - RocmAttentionParameters* attn, - const std::vector& qkv, - const std::vector& past, - const std::vector& present); - -template -Status LaunchStridedCopy( - hipStream_t stream, - const T* in, int4 in_shape, LongLong4 in_strides, const int* in_seqlens_offset, // coord (b,n,s,h) - T* out, LongLong4 out_strides, const int* out_seqlens_offset, // coord (b,n,s,h) - int max_threads_per_block); - -template -Status LaunchStridedCopy(hipStream_t stream, - const T* in, int4 in_shape, LongLong4 in_strides, // coord (b,n,s,h) - T* out, LongLong4 out_strides, // coord (b,n,s,h) - int max_threads_per_block); -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/attention_softmax.h b/onnxruntime/contrib_ops/rocm/bert/attention_softmax.h deleted file mode 100644 index 9f2faa228cf79..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/attention_softmax.h +++ /dev/null @@ -1,465 +0,0 @@ -#include "hip/hip_runtime.h" -/* - The implementation of this file is based on qkvToContext plugin in TensorRT demo: - https://github.com/NVIDIA/TensorRT/tree/release/5.1/demo/BERT/ - -Copyright 2019 NVIDIA Corporation - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -#pragma once - -#include -#include -#include -#include "core/providers/rocm/cu_inc/common.cuh" -#include "core/providers/rocm/rocm_common.h" -#include "core/providers/rocm/math/softmax.h" - -#define ROCMRT_INF_F __int_as_float(0x7f800000) - -using namespace onnxruntime::rocm; -using namespace hipcub; - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -template -__device__ inline void Softmax(const int all_sequence_length, - const int valid_end, - const int valid_start, - const T* attn_bias, - const T* input, - T* output) { - using BlockReduce = hipcub::BlockReduce; - __shared__ typename BlockReduce::TempStorage tmp_storage; - - __shared__ float sum_reverse_block; - __shared__ float max_block; - - float thread_data_max(-ROCMRT_INF_F); - - // e^x is represented as infinity if x is large enough, like 100.f. - // Infinity divided by Infinity is a NAN. Thus, softmax gets a NAN if one or more item are large enough. - // a math transform as below is leveraged to get a stable softmax: - // e^xi/(e^x1 + ...e^xn) = e^(xi - max) / (e^(x1 - max) + ... + e^(xn - max)) - const int offset = (blockIdx.y * gridDim.x + blockIdx.x) * all_sequence_length; - for (int i = threadIdx.x; i < valid_end; i += TPB) { - if (i >= valid_start) { - const int index = offset + i; - float input_at_idx = attn_bias == nullptr - ? static_cast(input[index]) - : static_cast(input[index] + attn_bias[index]); - if (thread_data_max < input_at_idx) { - thread_data_max = input_at_idx; - } - } - } - - const auto max = BlockReduce(tmp_storage).Reduce(thread_data_max, hipcub::Max()); - - // Store max value - if (threadIdx.x == 0) { - max_block = max; - } - __syncthreads(); - - float thread_data_sum(0.f); - for (int i = threadIdx.x; i < valid_end; i += TPB) { - if (i >= valid_start) { - const int index = offset + i; - float val = attn_bias == nullptr ? input[index] : input[index] + attn_bias[index]; - thread_data_sum += expf(val - max_block); - } - } - - const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_sum, hipcub::Sum()); - if (threadIdx.x == 0) { - sum_reverse_block = 1.f / sum; - } - __syncthreads(); - - for (int i = threadIdx.x; i < all_sequence_length; i += TPB) { - const int index = offset + i; - float input_at_idx = attn_bias == nullptr - ? static_cast(input[index]) - : static_cast(input[index] + attn_bias[index]); - const float val = (i >= valid_start && i < valid_end) ? expf(input_at_idx - max_block) * sum_reverse_block : 0.f; - output[index] = T(val); - } -} - -template -__device__ inline void SoftmaxSmall(const int all_sequence_length, - const int sequence_length, - const int valid_end, - const int valid_start, - const T* attn_bias, - const T* input, - T* output, - bool causal) { - using BlockReduce = hipcub::BlockReduce; - __shared__ typename BlockReduce::TempStorage tmp_storage; - - __shared__ float sum_reverse_block; - __shared__ float max_block; - - // Input dimension is BxNxSxS*; blockIdx.y is batch index b; gridDim.x=N*S; blockIdx.x is index within N*S; - const int offset = (blockIdx.y * gridDim.x + blockIdx.x) * all_sequence_length; - const int index = offset + threadIdx.x; - - bool is_valid = false; // whether it has attention mask == 1. - - // Update end position for causal. - int end = valid_end; - if (causal) { - const int end_causal = all_sequence_length - sequence_length + (blockIdx.x % sequence_length) + 1; - if (end_causal < end) { - end = end_causal; - } - } - - is_valid = (threadIdx.x >= valid_start && threadIdx.x < end); - - // e^x is represented as infinity if x is large enough, like 100.f. - // Infinity divided by Infinity is a NAN. Thus, softmax gets a NAN if one or more item are large enough. - // a math transform as below is leveraged to get a stable softmax: - // e^xi/(e^x1 + ...e^xn) = e^(xi - max) / (e^(x1 - max) + ... + e^(xn - max)) - float input_data = attn_bias == nullptr - ? static_cast(input[index]) - : static_cast(input[index] + attn_bias[index]); - float thread_data_max = is_valid ? input_data : float(-ROCMRT_INF_F); - const auto max = BlockReduce(tmp_storage).Reduce(thread_data_max, hipcub::Max(), end); - - // Store max value - if (threadIdx.x == 0) { - max_block = max; - } - __syncthreads(); - - float thread_data_exp(0.f); - if (is_valid) { - thread_data_exp = expf(input_data - max_block); - } - - const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_exp, hipcub::Sum(), end); - - // Store value of 1.0/sum. - if (threadIdx.x == 0) { - sum_reverse_block = (1.f) / sum; - } - __syncthreads(); - - // threadIdx.x might be larger than all_sequence_length due to alignment to 32x. - if (threadIdx.x < all_sequence_length) { - output[index] = is_valid ? T(thread_data_exp * sum_reverse_block) : T(0.f); - } -} - -// Note about the attention_mask_strides and attention_mask/key_padding_mask -// attention_mask accepts 2D, 3D or 4D tensor, but it will be viewed as 3D tensor uniformally and it will be indexed -// as [batch_index, sequence_index, token_index]. -template -__global__ void SoftmaxWithRawMaskSmallKernel( - const int all_sequence_length, - const int sequence_length, - const int3 attention_mask_strides, - const int* attention_mask, // 2D, 3D or 4D attention mask - const bool* key_padding_mask, - const T* attn_bias, - const T* input, - T* output, - const bool causal, - const float rsqrt_head_size, - const bool skip_softmax, - const float mask_filter_value) { - using BlockReduce = hipcub::BlockReduce; - __shared__ typename BlockReduce::TempStorage tmp_storage; - - __shared__ float sum_reverse_block; - __shared__ float max_block; - - // Input dimension is BxNxSxS*; blockIdx.y is batch index b; gridDim.x=N*S; blockIdx.x is index within N*S; - int index = (blockIdx.y * gridDim.x + blockIdx.x) * all_sequence_length + threadIdx.x; - - // Mask all thread_data values to negative infinity to allow BlockReduce Max operation over all thread_data - // members with all invalid members set to a value that does not impact the final result. This is necessary - // to avoid the performance impact from using the valid_items interface. - float thread_data = -ROCMRT_INF_F; - if (threadIdx.x < all_sequence_length) { - thread_data = float(input[index]) * rsqrt_head_size; - - const int sequence_index = blockIdx.x % sequence_length; - if (causal) { - int from_index = all_sequence_length - sequence_length + sequence_index; // offset in all sequence length. - if (threadIdx.x > from_index) { - thread_data = -ROCMRT_INF_F; - } - } - - const int batch_index = blockIdx.y; - int mask_offset = attention_mask_strides.x * batch_index + - attention_mask_strides.y * sequence_index + - attention_mask_strides.z * threadIdx.x; - - if (nullptr == key_padding_mask) { - const int& mask = attention_mask[mask_offset]; - if (mask == 0) - thread_data += mask_filter_value; - } else { - const bool mask = key_padding_mask[mask_offset]; - if (mask) { - thread_data = -ROCMRT_INF_F; - } - } - - if (attn_bias != nullptr) { - thread_data += float(attn_bias[index]); - } - } - - if (skip_softmax) { - if (threadIdx.x < all_sequence_length) { - output[index] = T(thread_data); - } - return; - } - - const float max = BlockReduce(tmp_storage).Reduce(thread_data, hipcub::Max()); - - // Store max value - if (threadIdx.x == 0) { - max_block = max; - } - __syncthreads(); - - // Mask all thread_data_exp values to zero to allow BlockReduce Sum operation over all thread_data_exp - // members with all invalid members set to a value that does not impact the final result. This is necessary - // to avoid the performance impact from using the valid_items interface. - float thread_data_exp = threadIdx.x < all_sequence_length ? expf(thread_data - max_block) : 0.0f; - const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_exp, hipcub::Sum()); - - // Store value of 1.0/sum - if (threadIdx.x == 0) { - sum_reverse_block = (1.f) / sum; - } - __syncthreads(); - - if (threadIdx.x < all_sequence_length) { - output[index] = T(thread_data_exp * sum_reverse_block); - } -} - -template -__global__ void SoftmaxKernelSmall(const int all_sequence_length, const int sequence_length, - const T* attn_bias, const T* input, T* output, bool causal) { - SoftmaxSmall(all_sequence_length, sequence_length, all_sequence_length, 0, - attn_bias, input, output, causal); -} - -template -__global__ void SoftmaxKernel(const int all_sequence_length, const T* attn_bias, const T* input, T* output) { - Softmax(all_sequence_length, all_sequence_length, 0, attn_bias, input, output); -} - -template -Status ComputeSoftmax( - hipStream_t stream, - const int all_sequence_length, const int sequence_length, const int batch_size, const int num_heads, - const T* attn_bias, const T* input, T* output, bool causal) { - const dim3 grid(sequence_length * num_heads, batch_size, 1); - if (all_sequence_length <= 32) { - const int blockSize = 32; - SoftmaxKernelSmall<<>>( - all_sequence_length, sequence_length, attn_bias, input, output, causal); - } else if (all_sequence_length <= 64) { - const int blockSize = 64; - SoftmaxKernelSmall<<>>( - all_sequence_length, sequence_length, attn_bias, input, output, causal); - } else if (all_sequence_length <= 128) { - const int blockSize = 128; - SoftmaxKernelSmall<<>>( - all_sequence_length, sequence_length, attn_bias, input, output, causal); - } else if (all_sequence_length <= 256) { - const int blockSize = 256; - SoftmaxKernelSmall<<>>( - all_sequence_length, sequence_length, attn_bias, input, output, causal); - } else if (all_sequence_length <= 512) { - const int blockSize = 512; - SoftmaxKernelSmall<<>>( - all_sequence_length, sequence_length, attn_bias, input, output, causal); - } else if (all_sequence_length <= 1024) { - const int blockSize = 1024; - SoftmaxKernelSmall<<>>( - all_sequence_length, sequence_length, attn_bias, input, output, causal); - } else if (!causal) { - const int blockSize = 1024; - SoftmaxKernel<<>>( - all_sequence_length, attn_bias, input, output); - } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Attention ROCM operator does not support total sequence length > 1024."); - } - - return HIP_CALL(hipPeekAtLastError()); -} - -template -__global__ void MaskedSoftmaxKernelSmall(const int all_sequence_length, const int sequence_length, - const int* mask_end, const int* mask_start, - const T* attn_bias, const T* input, T* output, - bool causal) { - __shared__ int start_position; - __shared__ int end_position; - - if (threadIdx.x == 0) { - const int batch = blockIdx.y; - start_position = mask_start != nullptr ? max(0, mask_start[batch]) : 0; - end_position = min(all_sequence_length, mask_end[batch]); - - // Attend to no word has same effect as attend to all words. This is added to get parity with CPU result. - if (start_position >= end_position) { - start_position = 0; - end_position = all_sequence_length; - } - } - __syncthreads(); - - SoftmaxSmall(all_sequence_length, sequence_length, end_position, start_position, - attn_bias, input, output, causal); -} - -template -__global__ void MaskedSoftmaxKernel(const int all_sequence_length, const int* mask_end, const int* mask_start, - const T* attn_bias, const T* input, T* output) { - __shared__ int start_position; - __shared__ int end_position; - - if (threadIdx.x == 0) { - const int batch = blockIdx.y; - start_position = mask_start != nullptr ? max(0, mask_start[batch]) : 0; - end_position = min(all_sequence_length, mask_end[batch]); - - // Attend to no word has same effect as attend to all words. This is added to get parity with CPU result. - if (start_position >= end_position) { - start_position = 0; - end_position = all_sequence_length; - } - } - __syncthreads(); - - Softmax(all_sequence_length, end_position, start_position, attn_bias, input, output); -} - -template -Status ComputeSoftmaxWithMask1D( - hipStream_t stream, - const int all_sequence_length, const int sequence_length, const int batch_size, const int num_heads, - const int* mask_index, const int* mask_start, - const T* attn_bias, const T* input, T* output, const bool causal) { - const dim3 grid(sequence_length * num_heads, batch_size, 1); - -#define DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(block_size) \ - MaskedSoftmaxKernelSmall<<>>( \ - all_sequence_length, sequence_length, mask_index, mask_start, \ - attn_bias, input, output, causal); - - if (all_sequence_length <= 32) { - DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(32); - } else if (all_sequence_length <= 64) { - DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(64); - } else if (all_sequence_length <= 128) { - DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(128); - } else if (all_sequence_length <= 256) { - DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(256); - } else if (all_sequence_length <= 512) { - DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(512); - } else if (all_sequence_length <= 1024) { - DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(1024); - } else if (!causal) { - const int blockSize = 1024; - MaskedSoftmaxKernel<<>>( - all_sequence_length, mask_index, mask_start, - attn_bias, input, output); - } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Attention ROCM operator does not support total sequence length > 1024."); - } - -#undef DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE - - return HIP_CALL(hipPeekAtLastError()); -} - -template -Status ComputeSoftmaxWithRawMask(Stream* ort_stream, - const int all_sequence_length, - const int sequence_length, - const int batch_size, - const int num_heads, - const int3 attention_mask_strides, - const int* attention_mask, - const bool* key_padding_mask, - const T* attn_bias, - const T* input, - T* output, - const bool causal, - const float rsqrt_head_size, - const bool use_persistent_softmax, - T* persistent_softmax_workspace, - const float mask_filter_value) { - const dim3 grid(sequence_length * num_heads, batch_size, 1); - - T* out = use_persistent_softmax ? persistent_softmax_workspace : output; - auto stream = static_cast(ort_stream->GetHandle()); - -#define DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(block_size) \ - SoftmaxWithRawMaskSmallKernel<<>>( \ - all_sequence_length, sequence_length, attention_mask_strides, \ - attention_mask, key_padding_mask, attn_bias, input, out, \ - causal, rsqrt_head_size, \ - use_persistent_softmax, mask_filter_value); - - if (all_sequence_length <= 32) { - DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(32); - } else if (all_sequence_length <= 64) { - DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(64); - } else if (all_sequence_length <= 128) { - DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(128); - } else if (all_sequence_length <= 256) { - DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(256); - } else if (all_sequence_length <= 512) { - DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(512); - } else if (all_sequence_length <= 1024) { - DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(1024); - } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Attention ROCM operator does not support total sequence length > 1024."); - } - -#undef DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE - - if (use_persistent_softmax) { - return dispatch_warpwise_softmax_forward(ort_stream, - output, - persistent_softmax_workspace, - all_sequence_length, - all_sequence_length, - batch_size * num_heads * sequence_length); - } - - return HIP_CALL(hipPeekAtLastError()); -} - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_permute_pipelines.cuh b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_permute_pipelines.cuh deleted file mode 100644 index 213940f132963..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_permute_pipelines.cuh +++ /dev/null @@ -1,125 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "core/providers/rocm/rocm_common.h" -#include "core/providers/rocm/rocm_kernel.h" -#include "core/providers/rocm/tunable/gemm.h" -#include "core/providers/rocm/tunable/rocm_tunable.h" -#include "contrib_ops/cpu/bert/attention_common.h" -#include "contrib_ops/cpu/bert/attention_parameters.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -namespace blas = onnxruntime::rocm::tunable::blas; - -namespace { -std::tuple GetQkvProjectGemmMNKBatch(const AttentionParameters* attention) { - int m = attention->sequence_length; - int n = (attention->hidden_size + attention->hidden_size + attention->v_hidden_size); // q + k + v - int k = attention->input_hidden_size; - int batch = attention->batch_size; - return {m, n, k, batch}; -} -} // namespace - -template -struct GemmPermuteParams : onnxruntime::rocm::tunable::OpParams { - std::string Signature() const override { - auto [m, n, k, batch] = GetQkvProjectGemmMNKBatch(attention); - return MakeString("M", m, "_N", n, "_K", k, "_B", batch); - } - - hipblasHandle_t handle; - const AttentionParameters* attention; - const hipDeviceProp_t* device_prop; - - const T* input_buffer; - const T* weight_buffer; - const T* bias_buffer; - T* out_buffer; - - int3 bias_strides; - - const T* ones; // used for broadcasting bias if the underlying algorithm does not support strides - T* workspace_buffer; -}; - -template -struct GemmPermuteGenericPipeline { - inline static size_t GetOutputNumBytes(const AttentionParameters* attn) { - auto [m, n, _, batch] = GetQkvProjectGemmMNKBatch(attn); - return sizeof(T) * m * n * batch; - } - - inline static size_t GetWorkspaceNumBytes(const AttentionParameters* attn) { - return GetOutputNumBytes(attn); - } - - inline static std::tuple GetGemmMNK(const GemmPermuteParams* params) { - auto [m, n, k, batch] = GetQkvProjectGemmMNKBatch(params->attention); - return {batch * m, n, k}; - } - - inline static std::tuple UnspliceOutputQKV(const GemmPermuteParams* params) { - auto* attn = params->attention; - int64_t batch = attn->batch_size * attn->num_heads; - int64_t num_elems_per_batch = attn->sequence_length * attn->head_size; - int64_t num_elems = batch * num_elems_per_batch; - auto q = params->out_buffer + 0 * num_elems; - auto k = params->out_buffer + 1 * num_elems; - auto v = params->out_buffer + 2 * num_elems; - return {q, k, v}; - } - - inline static Status BroadcastBias(const GemmPermuteParams* params) { - auto [m, n, k] = GetGemmMNK(params); - // Bias shape is (N), broadcast using B(M, N) = ones(M, 1) x bias(1, N). - // TODO: use custom kernel of expand to improve the performance. - return blas::row_major::Gemm( - params->TuningContext(), params->Stream(), params->handle, - blas::BlasOp::NonTrans, blas::BlasOp::NonTrans, - m, n, 1, - /*alpha=*/1.0f, - params->ones, 1, - params->bias_buffer, n, - /*beta=*/0.0f, - params->workspace_buffer, n); - } - - inline static Status Gemm(const GemmPermuteParams* params) { - auto [m, n, k] = GetGemmMNK(params); - // result(M, N) = input x weights + bias. - return blas::row_major::Gemm( - params->TuningContext(), params->Stream(), params->handle, - blas::BlasOp::NonTrans, blas::BlasOp::NonTrans, - m, n, k, - /*alpha=*/1.0f, - params->input_buffer, k, - params->weight_buffer, n, - /*beta=*/1.0f, - params->workspace_buffer, n); - } - - inline static Status Permute0213(const GemmPermuteParams* params) { - auto* attn = params->attention; - // input should be BxSx3xNxH => gemm_buffer: 3xBxNxSxH - return LaunchTransQkv( - params->StreamHandle(), 3, attn->sequence_length, attn->batch_size, attn->head_size, attn->num_heads, - params->device_prop->maxThreadsPerBlock, false, params->workspace_buffer, params->out_buffer); - } - - static Status Run(const GemmPermuteParams* params) { - ORT_RETURN_IF_ERROR(BroadcastBias(params)); - ORT_RETURN_IF_ERROR(Gemm(params)); - ORT_RETURN_IF_ERROR(Permute0213(params)); - return Status::OK(); - } -}; - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl.cuh b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl.cuh deleted file mode 100644 index be8508670e4b1..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl.cuh +++ /dev/null @@ -1,177 +0,0 @@ -// SPDX-License-Identifier: MIT -// Modifications Copyright (c) Microsoft Corporation. All rights reserved. -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#ifdef USE_COMPOSABLE_KERNEL -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp" -#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp" -#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" -#include "ck/utility/data_type.hpp" - -namespace onnxruntime { -namespace contrib { -namespace rocm { -namespace internal { - -using F16 = ck::half_t; -using F32 = float; - -template -using S = ck::Sequence; - -using MaskingSpecialization = ck::tensor_operation::device::MaskingSpecialization; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -using ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute; // the interface -using ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle; // the implementation - -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; -static constexpr auto GemmPadded = ck::tensor_operation::device::GemmSpecialization::MNKOPadding; - -static constexpr auto TensorDefault = ck::tensor_operation::device::TensorSpecialization::Default; - -template -using device_batched_gemm_softmax_gemm_permute_instances = - std::tuple< - // clang-format off - // #############################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| AData| B0Data| B1Data| CData| Acc0BiasData| Acc1BiasData| AccData| CShuffle| A| B0| Acc0| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| MaskingSpec| D0s Bias| - // #############################################| | | | | | Type| Type| Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | SrcScalar| - // #############################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | PerVector| - // #############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec, 1>, - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, -#if ROCM_VERSION >= 50500 - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 256, 32, 64, 32, 8, 8, 2, 32, 32, 1, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, -#endif - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 256, 32, 128, 32, 8, 8, 2, 32, 32, 1, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 32, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8, MaskingSpec>, - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 32, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8, MaskingSpec>, - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 64, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8, MaskingSpec>, - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 64, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8, MaskingSpec>, - // Padded fallback kernel - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec, 1>, - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec> - // clang-format on - >; - -struct PreSoftmaxAttentionScoreOp { - PreSoftmaxAttentionScoreOp(float scale) : scale_(scale) {} - - // non-biased, non-masked - __host__ __device__ void operator()(float& y, const float& x) const { - y = scale_ * x; - } - - // biased or converted masked - __host__ __device__ void operator()(float& y, const float& x, const F16& bias) const { - y = scale_ * x + ck::type_convert(bias); - } - - // biased and converted masked - __host__ __device__ void operator()(float& y, const float& x, const F16& bias, const F16& converted_mask) const { - y = scale_ * x + ck::type_convert(bias) + ck::type_convert(converted_mask); - } - - float scale_; -}; - -// Use this function to gat implementation -template -std::vector, - PassThrough, PassThrough, D0Op, PassThrough, PassThrough, - MaskingSpec>>> -GetDeviceBatchedGemmSoftmaxGemmPermuteInstances() { - return {}; -} - -// implemented in impl_{fp16,bf16}[_biased][_masked].cu -// fp16, non-biased, non-masked -template <> -std::vector, ck::Tuple<>, - PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, - MaskingSpecialization::MaskDisabled>>> -GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< - F16, ck::Tuple<>, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskDisabled>(); - -// fp16, biased, non-masked -template <> -std::vector, ck::Tuple<>, - PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, - MaskingSpecialization::MaskDisabled>>> -GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< - F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskDisabled>(); - -// fp16, biased, fp16 masked, basically, two bias -template <> -std::vector, ck::Tuple<>, - PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, - MaskingSpecialization::MaskDisabled>>> -GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< - F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskDisabled>(); - -template <> -std::vector, ck::Tuple<>, - PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, - MaskingSpecialization::MaskOutUpperTriangle>>> -GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< - F16, ck::Tuple<>, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>(); - -// fp16, biased, non-masked -template <> -std::vector, ck::Tuple<>, - PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, - MaskingSpecialization::MaskOutUpperTriangle>>> -GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< - F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>(); - -// fp16, biased, fp16 masked, basically, two bias -template <> -std::vector, ck::Tuple<>, - PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, - MaskingSpecialization::MaskOutUpperTriangle>>> -GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< - F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>(); - -} // namespace internal -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime -#endif // USE_COMPOSABLE_KERNEL diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16.cu b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16.cu deleted file mode 100644 index 2e32a6594d164..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16.cu +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#ifdef USE_COMPOSABLE_KERNEL -#include "contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl.cuh" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp" - -namespace onnxruntime { -namespace contrib { -namespace rocm { -namespace internal { - -using NonBiasedNonmasked = DeviceBatchedGemmSoftmaxGemmPermute< - 2, 1, 1, 1, 1, - F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, - PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, - MaskingSpecialization::MaskDisabled>; - -template <> -std::vector> -GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< - F16, ck::Tuple<>, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskDisabled>() { - std::vector> instances; - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, - device_batched_gemm_softmax_gemm_permute_instances< - 2, 1, 1, 1, 1, - F16, ck::Tuple<>, F32, PreSoftmaxAttentionScoreOp, - MaskingSpecialization::MaskDisabled>{}); - - return instances; -} - -using NonBiasedNonmaskedCausal = DeviceBatchedGemmSoftmaxGemmPermute< - 2, 1, 1, 1, 1, - F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, - PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, - MaskingSpecialization::MaskOutUpperTriangle>; - -template <> -std::vector> -GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< - F16, ck::Tuple<>, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>() { - std::vector> instances; - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, - device_batched_gemm_softmax_gemm_permute_instances< - 2, 1, 1, 1, 1, - F16, ck::Tuple<>, F32, PreSoftmaxAttentionScoreOp, - MaskingSpecialization::MaskOutUpperTriangle>{}); - - return instances; -} - -} // namespace internal -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime -#endif // USE_COMPOSABLE_KERNEL diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased.cu b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased.cu deleted file mode 100644 index 91da8d9e1f9a8..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased.cu +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#ifdef USE_COMPOSABLE_KERNEL -#include "contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl.cuh" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp" - -namespace onnxruntime { -namespace contrib { -namespace rocm { -namespace internal { - -using BiasedNonmasked = DeviceBatchedGemmSoftmaxGemmPermute< - 2, 1, 1, 1, 1, - F16, F16, F16, F16, ck::Tuple, ck::Tuple<>, - PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, - MaskingSpecialization::MaskDisabled>; - -template <> -std::vector> -GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< - F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskDisabled>() { - std::vector> instances; - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, - device_batched_gemm_softmax_gemm_permute_instances< - 2, 1, 1, 1, 1, - F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, - MaskingSpecialization::MaskDisabled>{}); - - return instances; -} - -using BiasedNonmaskedCausal = DeviceBatchedGemmSoftmaxGemmPermute< - 2, 1, 1, 1, 1, - F16, F16, F16, F16, ck::Tuple, ck::Tuple<>, - PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, - MaskingSpecialization::MaskOutUpperTriangle>; - -template <> -std::vector> -GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< - F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>() { - std::vector> instances; - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, - device_batched_gemm_softmax_gemm_permute_instances< - 2, 1, 1, 1, 1, - F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, - MaskingSpecialization::MaskOutUpperTriangle>{}); - - return instances; -} - -} // namespace internal -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime -#endif // USE_COMPOSABLE_KERNEL diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased_biased.cu b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased_biased.cu deleted file mode 100644 index b08123be18977..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased_biased.cu +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#ifdef USE_COMPOSABLE_KERNEL -#include "contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl.cuh" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp" - -namespace onnxruntime { -namespace contrib { -namespace rocm { -namespace internal { - -using BiasedNonmasked = DeviceBatchedGemmSoftmaxGemmPermute< - 2, 1, 1, 1, 1, - F16, F16, F16, F16, ck::Tuple, ck::Tuple<>, - PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, - MaskingSpecialization::MaskDisabled>; - -template <> -std::vector> -GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< - F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskDisabled>() { - std::vector> instances; - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, - device_batched_gemm_softmax_gemm_permute_instances< - 2, 1, 1, 1, 1, - F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, - MaskingSpecialization::MaskDisabled>{}); - - return instances; -} - -using BiasedNonmaskedCausal = DeviceBatchedGemmSoftmaxGemmPermute< - 2, 1, 1, 1, 1, - F16, F16, F16, F16, ck::Tuple, ck::Tuple<>, - PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, - MaskingSpecialization::MaskOutUpperTriangle>; - -template <> -std::vector> -GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< - F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>() { - std::vector> instances; - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, - device_batched_gemm_softmax_gemm_permute_instances< - 2, 1, 1, 1, 1, - F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, - MaskingSpecialization::MaskOutUpperTriangle>{}); - - return instances; -} - -} // namespace internal -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime -#endif // USE_COMPOSABLE_KERNEL diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh deleted file mode 100644 index 226b89cfb2b86..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh +++ /dev/null @@ -1,915 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -/* About Computing in these Pipelines - -B: batch size of Attention Op. NOTE: To be disambiguated with batch size of GEMMs -S: sequence length -T: total sequence length -N: num of heads -H: head dimension - -The following use qkv_format == Q_K_V_BNSH (mode == BSNH_BNLH_BNLH_NONE_NONE_NONE_NONE) as a example: - -BN: B*N, which is the batch size of GEMMs. NOTE: To be disambiguated with batch size of Attention Op - -In QKV projection (prior to this pipeline): - /-> Q [B,S,N*H] ->Reshape-> [B,S,N,H] ->Permute0213-> [B,N,S,H] -X --o--> K [B,T,N*H] ->Reshape-> [B,T,N,H] ->Permute0213-> [B,N,T,H] - \-> V [B,T,N*H] ->Reshape-> [B,T,N,H] ->Permute0213-> [B,N,T,H] - -pre_softmax_attn_scores = Q*K' = [B,N,S,H] * [BxNxTxH]' = [B,N,S,T] Batched GEMM1 -pre_softmax_attn_scores_masked = pre_softmax_attn_scores * scale +? bias +? mask Scale Add Bias, +? is optional -attn_scores = softmax(pre_softmax_attn_scores_masked) = [B,N,S,T] Softmax -scaled_multi_head_attn = attn_scores * V = [B,N,S,T] * [B,N,T,H] = [B,N,S,H] Batched GEMM2 - -Op outputs scaled_multi_head_attn: -[B,N,S,H] ->Permute0213-> [B,S,N,H] ->Reshape-> [B,S,N*H] - - -For the computing of pre_softmax_attn_scores +? mask +? bias: - -GemmSoftmaxGemmPermuteGenericPipeline handles it in specialized softmax. TODO: remove it! - -CK in GemmSoftmaxGemmPermuteTunablePipeline - - Q*K' ---> scale ---> [B,N,S,T] -------+?--> masked - bias --------------> [B,N,S,T] --+?--/ -mask_2d ---> [B,T] ---> [B,1,1,T] -/ - - Q*K' ---> scale ---> [B,N,S,T] -------+?--> masked - bias --------------> [B,N,S,T] --+?--/ -mask_3d --> [B,S,T] --> [B,1,S,T] -/ - - Q*K' ---> scale ---> [B,N,S,T] -------+?--> masked - bias --------------> [B,N,S,T] --+?--/ -mask_4d -> [B,1,M,M] -> [B,1,S,T] -/ M is max_sequence_length from megatron, we will create a - **sub-view** from original mask buffer - -For CK implementation, there will be four cases combined: -non-biased, non-masked, no special processing. - biased, non-masked, no special processing, add the mask directly. -non-biased, masked, convert the mask to [B,1,1_or_S,T] and perform broadcast add with scaled Q*K'. - biased, masked, convert the mask to [B,1,1_or_S,T] and perform broadcast add with bias and scaled Q*K'. - -Broadcast add is not actually perform the broadcasting, just broadcast the load operation from memory. The impl details -are in composable kernels. The scale and add logic is performed via Acc0ElementOp - -# Classified modes: - -| Q | K | V | past(K)| pastV | present(K)| presentV | Op, desc -| ---- | ---- | ---- | ------ | ----- | --------- | -------- | --------- -| QFMT | KFMT | VFMT | - | - | - | - | A, basic, qkv is impl dependent by qkv_format -| QFMT | KFMT | VFMT | 2BNPH | - | 2BNTH *^ | - | A, past_present_share_buffer = false, qkv is impl dependent by qkv_format -| QFMT | KFMT | VFMT | 2BNMH | - | 2BNMH *^ | - | A, past_present_share_buffer = true, qkv is impl dependent by qkv_format -| BSNH | BLNH*| BLNH^| - | - | - | - | MHA basic -| BSNH | BNLH*| BNLH^| - | - | - | - | MHA cross, pass_past_in_kv = true -| BSNH | - | - | - | - | BNLH * | BNLH ^ | MHA cross, pass_past_in_kv = false -| BSNH | BLNH | BLNH | - | - | BNTH * | BNTH ^ | MHA cross, past_present_share_buffer = false -| BSNH | BNLH | BNLH | - | - | BNTH * | BNTH ^ | MHA cross, past_present_share_buffer = false -| BSNH | BLNH | BLNH | - | - | BNMH * | BNMH ^ | MHA cross, past_present_share_buffer = true -| BSNH | BNLH | BNLH | - | - | BNMH * | BNMH ^ | MHA cross, past_present_share_buffer = true -| BSNH | BLNH | BLNH | BNPH | BNPH | BNTH * | BNTH ^ | MHA self, past_present_share_buffer = false -| BSNH | BNLH | BNLH | BNPH | BNPH | BNTH * | BNTH ^ | MHA self, past_present_share_buffer = false -| BSNH | BLNH | BLNH | BNMH | BNMH | BNMH * | BNMH ^ | MHA self, past_present_share_buffer = true -| BSNH | BNLH | BNLH | BNMH | BNMH | BNMH * | BNMH ^ | MHA self, past_present_share_buffer = true -| BLN3H*^| - | - | - | - | - | - | MHA basic, qkv_packed -| BSNH | BLN2H*^| - | - | - | - | - | MHA basic, kv_packed - -Q, K, V, past(K), pastV, present(K), presentV is the Input of the contrib OpKernel - -About k_buffer and v_buffer, we always explicitly concat past to present and use present_k for k_buffer and present_v for v_buffer - -- Marked with `*` indicate the Tensor is used for k_buffer passing. -- Marked with `^` indicate the Tensor is used for v_buffer passing. - -# Supported Op - -- A: Attention -- MHA: MultiHeadAttention - -# Dim Value - -- B: batch_size -- N: num_heads -- H: head_size - -- S: sequence_length -- L: kv_sequence_length -- P: past_sequence_length -- T: total_sequence_length = P + L -- M: max_sequence_length -*/ - -#include "core/framework/tensor_shape.h" -#include "core/providers/rocm/tunable/gemm.h" -#include "core/providers/rocm/tunable/rocm_tunable.h" -#include "contrib_ops/cpu/bert/attention_base.h" -#include "contrib_ops/rocm/bert/attention_impl.h" -#include "contrib_ops/rocm/bert/attention_softmax.h" -#ifdef USE_COMPOSABLE_KERNEL -#include "contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl.cuh" -#include "core/providers/rocm/composable_kernel_common.h" - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#endif // USE_COMPOSABLE_KERNEL - -#include -#include - -namespace blas = onnxruntime::rocm::tunable::blas; - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -inline int3 Get2DMaskStrides(int total_sequence_length) { - // stride == 0 indicate broadcasting - return {total_sequence_length, 0, 1}; -} - -// A stride maps from natural coordinate to physical offset of underlying memory storage buffer offset. We need to -// specify both of the natural coordinate order, say (b,n,s,h), (b,s,n,h) or (b,n,h,s), and memory order, say BNSH or -// BSNH, to determain the strides. To obtain the offset, we just do the inner product of coordinate with the strides. -// This wrapper create the stride vector from the physical dimension (or physical shape). -struct Strides { - // Create the strides for BNSH physically indexed memory buffer - static Strides BNSHMemory(int batch_dim, - int num_head_dim, - int seqlen_dim, - int head_size_dim) { - ORT_UNUSED_PARAMETER(batch_dim); - return Strides{LongLong4{ - static_cast(num_head_dim) * seqlen_dim * head_size_dim, - static_cast(seqlen_dim) * head_size_dim, - static_cast(head_size_dim), - static_cast(1), - }}; - } - - // Create the strides for BSNH physically indexed memory buffer - static Strides BSNHMemory(int batch_dim, - int seqlen_dim, - int num_head_dim, - int head_size_dim) { - ORT_UNUSED_PARAMETER(batch_dim); - return Strides{LongLong4{ - static_cast(seqlen_dim) * num_head_dim * head_size_dim, - static_cast(head_size_dim), - static_cast(num_head_dim) * head_size_dim, - static_cast(1), - }}; - } - - template - T ForBNSHCoord() const { - using E = typename T::value_type; - return T{static_cast(strides_for_bnsh_coord.x), - static_cast(strides_for_bnsh_coord.y), - static_cast(strides_for_bnsh_coord.z), - static_cast(strides_for_bnsh_coord.w)}; - } - - template - T ForBSNHCoord() const { - using E = typename T::value_type; - return T{static_cast(strides_for_bnsh_coord.x), - static_cast(strides_for_bnsh_coord.z), - static_cast(strides_for_bnsh_coord.y), - static_cast(strides_for_bnsh_coord.w)}; - } - - template - T ForBNHSCoord() const { - using E = typename T::value_type; - return T{static_cast(strides_for_bnsh_coord.x), - static_cast(strides_for_bnsh_coord.y), - static_cast(strides_for_bnsh_coord.w), - static_cast(strides_for_bnsh_coord.z)}; - } - - int64_t OffsetAt(int b, int n, int s, int h) const { - return strides_for_bnsh_coord.x * b + strides_for_bnsh_coord.y * n + - strides_for_bnsh_coord.z * s + strides_for_bnsh_coord.w * h; - } - - // store intermediate strides in the canonical (b,n,s,h) coordinate order - LongLong4 strides_for_bnsh_coord; -}; - -template -std::tuple ConvertToOffsetedBufferViews( - const RocmAttentionParameters* attn, - const T* query = nullptr, // q or packed_qkv - const T* key = nullptr, // k or packed kv - const T* value = nullptr, // - const T* present = nullptr, // present or present_k - const T* present_v = nullptr) { - switch (attn->mode) { - case QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE: - case BSNH_BNLH_BNLH_NONE_NONE_NONE_NONE: - case BSNH_BLNH_BLNH_NONE_NONE_NONE_NONE: { - return {reinterpret_cast(query), - reinterpret_cast(key), - reinterpret_cast(value)}; - } - case QFMT_KFMT_VFMT_NONE_NONE_2BNTH_NONE: - case QFMT_KFMT_VFMT_2BNPH_NONE_2BNTH_NONE: { - auto offset = static_cast(attn->batch_size) * attn->num_heads * attn->total_sequence_length * - attn->head_size; - return {reinterpret_cast(query), - reinterpret_cast(present), - reinterpret_cast(present) + offset}; - } - case QFMT_KFMT_VFMT_NONE_NONE_2BNMH_NONE: - case QFMT_KFMT_VFMT_2BNMH_NONE_2BNMH_NONE: { - auto offset = static_cast(attn->batch_size) * attn->num_heads * attn->max_sequence_length * - attn->head_size; - return {reinterpret_cast(query), - reinterpret_cast(present), - reinterpret_cast(present) + offset}; - } - case BSNH_BLNH_BLNH_NONE_NONE_BNTH_BNTH: - case BSNH_BNLH_BNLH_NONE_NONE_BNTH_BNTH: - case BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH: - case BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH: - case BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH: - case BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH: - case BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH: - case BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH: - return {reinterpret_cast(query), - reinterpret_cast(present), - reinterpret_cast(present_v)}; - case BSNH_BLN2H_NONE_NONE_NONE_NONE_NONE: { - auto packed_kv = reinterpret_cast(key); - return {reinterpret_cast(query), packed_kv, packed_kv + attn->head_size}; - } - case BLN3H_NONE_NONE_NONE_NONE_NONE_NONE: { - auto packed_qkv = reinterpret_cast(query); - return {packed_qkv, packed_qkv + 1 * attn->head_size, packed_qkv + 2 * attn->head_size}; - } - default: - ORT_ENFORCE("unreachable"); - return {}; - } -} - -inline std::tuple GetQkvStrides(const RocmAttentionParameters* attn) { - // G0 not used, because it is the slowest dimension - const int& B = attn->batch_size; - const int& N = attn->num_heads; - const int& S = attn->sequence_length; - const int& L = attn->kv_sequence_length; - const int& T = attn->total_sequence_length; - const int& M = attn->max_sequence_length; - const int& H = attn->head_size; - const int& Hv = attn->v_head_size; - - switch (attn->mode) { - case QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE: - if (attn->qkv_format == Q_K_V_BNSH) { - return { - Strides::BNSHMemory(B, N, S, H), - Strides::BNSHMemory(B, N, L, H), - Strides::BNSHMemory(B, N, L, Hv), - }; - } else if (attn->qkv_format == Q_K_V_BSNH) { - return { - Strides::BSNHMemory(B, S, N, H), - Strides::BSNHMemory(B, L, N, H), - Strides::BSNHMemory(B, L, N, Hv), - }; - } - case BSNH_BLNH_BLNH_NONE_NONE_BNTH_BNTH: - case BSNH_BNLH_BNLH_NONE_NONE_BNTH_BNTH: - case BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH: - case BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH: - return { - Strides::BSNHMemory(B, S, N, H), - Strides::BNSHMemory(B, N, T, H), - Strides::BNSHMemory(B, N, T, Hv), - }; - case BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH: - case BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH: - case BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH: - case BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH: - return { - Strides::BSNHMemory(B, S, N, H), - Strides::BNSHMemory(B, N, M, H), - Strides::BNSHMemory(B, N, M, Hv), - }; - case BSNH_BLNH_BLNH_NONE_NONE_NONE_NONE: - return { - Strides::BSNHMemory(B, S, N, H), - Strides::BSNHMemory(B, L, N, H), - Strides::BSNHMemory(B, L, N, Hv), - }; - case BSNH_BNLH_BNLH_NONE_NONE_NONE_NONE: - return { - Strides::BSNHMemory(B, S, N, H), - Strides::BNSHMemory(B, N, L, H), - Strides::BNSHMemory(B, N, L, Hv), - }; - case BSNH_BLN2H_NONE_NONE_NONE_NONE_NONE: - return { - Strides::BSNHMemory(B, S, N, H), - Strides::BSNHMemory(B, L, N, 2 * H), - Strides::BSNHMemory(B, L, N, 2 * Hv), - }; - case BLN3H_NONE_NONE_NONE_NONE_NONE_NONE: - return { - Strides::BSNHMemory(B, L, N, 3 * H), - Strides::BSNHMemory(B, L, N, 3 * H), - Strides::BSNHMemory(B, L, N, 3 * Hv), - }; - default: - ORT_ENFORCE("unreachable"); - return {}; - } -} - -inline std::tuple GetRawMaskBufferAddrSizesAndStrides( - const int* buffer, const RocmAttentionParameters* attn) { - const int* offseted_buffer{buffer}; // how to view the mask buffer - int3 sizes{0, 0, 0}; // the logical shape of the view - int3 strides{-1, -1, -1}; // the physical memory layout - switch (attn->mask_type) { - case MASK_NONE: - case MASK_2D_DUMMY: - break; // No mask - case MASK_2D_KEY_PADDING: - sizes = {attn->batch_size, 1, attn->total_sequence_length}; - strides = Get2DMaskStrides(attn->total_sequence_length); - break; - case MASK_3D_ATTENTION: - sizes = {attn->batch_size, attn->sequence_length, attn->total_sequence_length}; - strides = {attn->sequence_length * attn->total_sequence_length, attn->total_sequence_length, 1}; - break; - case MASK_4D_MEGATRON: - // offset to skip past sequence part, so that we can index it with [batch_index, sequence_index, token_index] - offseted_buffer = buffer + attn->past_sequence_length * attn->max_sequence_length; - sizes = {attn->batch_size, attn->sequence_length, attn->total_sequence_length}; - strides = {attn->max_sequence_length * attn->max_sequence_length, attn->max_sequence_length, 1}; - break; - default: - LOGS_DEFAULT(FATAL) << "unsupported mask type: " << attn->mask_type; - throw std::runtime_error("unsupported mask type"); - } - return {offseted_buffer, sizes, strides}; -} - -template -struct GemmSoftmaxGemmPermuteParams : onnxruntime::rocm::tunable::OpParams { - std::string Signature() const override { - return MakeString( - "B", attention->batch_size, - "_S", attention->sequence_length, - "_T", attention->total_sequence_length, - "_N", attention->num_heads, - "_H", attention->head_size, - "_Hv", attention->v_head_size, - bias_buffer != nullptr ? "_B" : "_NB", - "_M", mask_index_dims.size(), - "_QKV", attention->qkv_format, - "_MODE", attention->mode); - } - - std::tuple GetGemmsMNKOBatch() const { - ORT_ENFORCE(attention != nullptr); - auto m = attention->sequence_length; - auto n = attention->total_sequence_length; - auto k = attention->head_size; - auto o = attention->v_head_size; - auto batch = attention->batch_size * attention->num_heads; - return {m, n, k, o, batch}; - } - - hipblasHandle_t handle; - const RocmAttentionParameters* attention; - const hipDeviceProp_t* device_prop; - - float scale; - const T* q_buffer; - const T* k_buffer; - const T* v_buffer; - T* out_buffer; - - // optional, attention bias [B,N,S,T] - // TODO: support shape [B,1,S,T], [1, N, S, T], [1, 1, S, T] with broadcast. - const T* bias_buffer{nullptr}; - - // optional, mask value - const int* mask_index_buffer{nullptr}; - TensorShapeVector mask_index_dims{}; - - // optional, internal - void* workspace_buffer{nullptr}; -}; - -inline bool IsKVBNMH(AttentionMode mode) { - switch (mode) { - case QFMT_KFMT_VFMT_NONE_NONE_2BNMH_NONE: - case QFMT_KFMT_VFMT_2BNMH_NONE_2BNMH_NONE: - case BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH: - case BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH: - case BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH: - case BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH: - return true; - default: - return false; - } -} - -template -struct GemmSoftmaxGemmPermuteGenericPipeline { - static bool UseRawAttentionMask(const GemmSoftmaxGemmPermuteParams* params) { - return params->mask_index_buffer != nullptr && params->mask_index_dims.size() >= 2; - } - - static std::tuple GetWorkspacePlan(const GemmSoftmaxGemmPermuteParams* params) { - auto bytes = GetAttentionScratchSize( - sizeof(T), - params->attention->batch_size, - params->attention->num_heads, - params->attention->sequence_length, - params->attention->total_sequence_length); - auto gemm1_out = reinterpret_cast(params->workspace_buffer); - auto softmax_out = gemm1_out + (bytes / sizeof(T)); - auto gemm2_out = softmax_out + (bytes / sizeof(T)); - return {gemm1_out, softmax_out, gemm2_out}; - } - - inline static size_t GetWorkspaceNumBytes(const RocmAttentionParameters* attn) { - return GetAttentionWorkspaceSize( - sizeof(T), - attn->batch_size, - attn->num_heads, - attn->head_size, - attn->sequence_length, - attn->total_sequence_length); - } - - inline static Status Gemm1(const GemmSoftmaxGemmPermuteParams* params) { - auto [m, n, k, o, batch] = params->GetGemmsMNKOBatch(); - auto [gemm1_out, softmax_out, gemm2_out] = GetWorkspacePlan(params); - - int k_buffer_stride = n * k; - if (IsKVBNMH(params->attention->mode)) { - k_buffer_stride = params->attention->max_sequence_length * params->attention->head_size; - } - - // GEMM1 [m,k] * [n,k]' -> [m,n] - return blas::row_major::StridedBatchedGemm( - params->TuningContext(), params->Stream(), params->handle, - blas::BlasOp::NonTrans, blas::BlasOp::Trans, - m, n, k, - // For raw attention mask, the scalar is moved to softmax computation. - /*alpha=*/UseRawAttentionMask(params) ? 1.0f : params->scale, - params->q_buffer, k, m * k, - params->k_buffer, k, k_buffer_stride, - /*beta=*/0.0f, - gemm1_out, n, m * n, - batch); - } - - inline static Status SoftmaxRawMask(const GemmSoftmaxGemmPermuteParams* params, bool use_persistent_softmax) { - // Softmax on [m,n] along the n dimension. - // Raw attention mask could be 2D (B,S) or 3D (B,S,T) or 4D(B,1,M,M), where M is the max sequence length. - auto attn = params->attention; - auto [buffer, sizes, strides] = GetRawMaskBufferAddrSizesAndStrides(params->mask_index_buffer, attn); - auto [gemm1_out, softmax_out, gemm2_out] = GetWorkspacePlan(params); - T* persistent_softmax_workspace = gemm1_out; // replace Q*K' in place if persistent softmax is selected. - return ComputeSoftmaxWithRawMask( - params->Stream(), attn->total_sequence_length, attn->sequence_length, attn->batch_size, attn->num_heads, - strides, buffer, nullptr, params->bias_buffer, gemm1_out, softmax_out, - attn->is_unidirectional, /* FIXME: this must not be attn.scale! */ params->scale, - use_persistent_softmax, persistent_softmax_workspace, attn->mask_filter_value); - } - - inline static Status Softmax1DIndexMask(const GemmSoftmaxGemmPermuteParams* params) { - auto mask_1d = params->mask_index_buffer; - auto mask_1d_size = params->mask_index_dims[0]; - // Softmax on [m,n] along the n dimension. - // mask_index has 1D shape: either (batch_size) or (2*batch_size). Only the later one has start postions. - auto attn = params->attention; - auto [gemm1_out, softmax_out, gemm2_out] = GetWorkspacePlan(params); - const int* mask_start = (mask_1d_size > attn->batch_size) ? mask_1d + attn->batch_size : nullptr; - return ComputeSoftmaxWithMask1D( - params->StreamHandle(), attn->total_sequence_length, attn->sequence_length, attn->batch_size, attn->num_heads, - mask_1d, mask_start, params->bias_buffer, gemm1_out, softmax_out, attn->is_unidirectional); - } - - inline static Status SoftmaxNoMask(const GemmSoftmaxGemmPermuteParams* params) { - // Softmax on [m,n] along the n dimension. - auto attn = params->attention; - auto [gemm1_out, softmax_out, gemm2_out] = GetWorkspacePlan(params); - return ComputeSoftmax( - params->StreamHandle(), attn->total_sequence_length, attn->sequence_length, attn->batch_size, attn->num_heads, - params->bias_buffer, gemm1_out, softmax_out, attn->is_unidirectional); - } - - inline static Status Gemm2(const GemmSoftmaxGemmPermuteParams* params) { - auto [m, n, k, o, batch] = params->GetGemmsMNKOBatch(); - auto [gemm1_out, softmax_out, gemm2_out] = GetWorkspacePlan(params); - - int v_buffer_stride = n * o; - if (IsKVBNMH(params->attention->mode)) { - v_buffer_stride = params->attention->max_sequence_length * params->attention->v_head_size; - } - - // GEMM2 [m,n] * [n,o] -> [m,o] - // semantically, the output buffer contains B*N matrices of shape [S,H], compactly, thus B,N,S,H. - return blas::row_major::StridedBatchedGemm( - params->TuningContext(), params->Stream(), params->handle, - blas::BlasOp::NonTrans, blas::BlasOp::NonTrans, - m, o, n, - /*alpha=*/1.0f, - softmax_out, n, m * n, - params->v_buffer, o, v_buffer_stride, - /*beta=*/0.0f, - gemm2_out, o, m * o, - batch); - } - - inline static Status Permute0213(const GemmSoftmaxGemmPermuteParams* params) { - // Permute 0213 - // gemm2_out is B,N,S,H, transpose to out_buffer as B,S,N,H - auto attn = params->attention; - auto [gemm1_out, softmax_out, gemm2_out] = GetWorkspacePlan(params); - return LaunchTransCtx( - params->StreamHandle(), - attn->sequence_length, attn->batch_size, attn->head_size, attn->num_heads, - params->device_prop->maxThreadsPerBlock, false, gemm2_out, params->out_buffer); - } - - static Status GetSupportedStatus(const GemmSoftmaxGemmPermuteParams* params) { - const auto& attn = params->attention; - // TODO: address the BNMH k,v strides - switch (attn->mode) { - case QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE: - case QFMT_KFMT_VFMT_NONE_NONE_2BNTH_NONE: - case QFMT_KFMT_VFMT_2BNPH_NONE_2BNTH_NONE: - case QFMT_KFMT_VFMT_NONE_NONE_2BNMH_NONE: - case QFMT_KFMT_VFMT_2BNMH_NONE_2BNMH_NONE: - if (attn->qkv_format == Q_K_V_BNSH) { - return Status::OK(); - } else { - return TUNABLE_OP_UNSUPPORTED("GenericPipeline only supports qkv_format as Q_K_V_BNSH, got ", - attn->qkv_format); - } - case BSNH_BLNH_BLNH_NONE_NONE_NONE_NONE: - return TUNABLE_OP_UNSUPPORTED("GenericPipeline only supports qkv_format as Q_K_V_BNSH but k, v are BLNH"); - case BSNH_BNLH_BNLH_NONE_NONE_NONE_NONE: - case BSNH_BLNH_BLNH_NONE_NONE_BNTH_BNTH: - case BSNH_BNLH_BNLH_NONE_NONE_BNTH_BNTH: - case BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH: - case BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH: - case BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH: - case BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH: - case BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH: - case BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH: - // If sequence_length is 1, query of B1NH can be simply viewed as BN1H. - if (attn->sequence_length == 1) { - return Status::OK(); - } else { - return TUNABLE_OP_UNSUPPORTED("GenericPipeline only supports qkv_format as Q_K_V_BNSH, ", - "only if sequence_length is 1, query of BSNH can be viewed as BNSH"); - } - case BLN3H_NONE_NONE_NONE_NONE_NONE_NONE: - case BSNH_BLN2H_NONE_NONE_NONE_NONE_NONE: - return TUNABLE_OP_UNSUPPORTED("GenericPipeline only supports qkv_format as Q_K_V_BNSH"); - default: - return TUNABLE_OP_UNSUPPORTED("unknonw"); - } - return TUNABLE_OP_UNSUPPORTED("unknonw case"); - } - - static Status Run(const GemmSoftmaxGemmPermuteParams* params, bool use_persistent_softmax) { - auto supported_status = GetSupportedStatus(params); - if (!supported_status.IsOK()) { - return supported_status; - } - ORT_RETURN_IF_ERROR(Gemm1(params)); - - if (UseRawAttentionMask(params)) { - ORT_RETURN_IF_ERROR(SoftmaxRawMask(params, use_persistent_softmax)); - } else if (params->mask_index_dims.size() == 1) { // 1d index mask - ORT_RETURN_IF_ERROR(Softmax1DIndexMask(params)); - } else { - ORT_RETURN_IF_ERROR(SoftmaxNoMask(params)); - } - - ORT_RETURN_IF_ERROR(Gemm2(params)); - ORT_RETURN_IF_ERROR(Permute0213(params)); - return Status::OK(); - } -}; - -template -class GemmSoftmaxGemmPermuteTunableOp : public tunable::TunableOp> { - public: - GemmSoftmaxGemmPermuteTunableOp(); - - inline static bool IsSupportedMode(const RocmAttentionParameters* attn) { - switch (attn->mode) { - case QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE: - case QFMT_KFMT_VFMT_2BNPH_NONE_2BNTH_NONE: - // depends on qkv format - if (attn->qkv_format == Q_K_V_BNSH || attn->qkv_format == Q_K_V_BSNH) { - return true; - } else { - return false; - } - case BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH: - case BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH: - case BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH: - case BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH: - case BSNH_BLNH_BLNH_NONE_NONE_NONE_NONE: - case BSNH_BNLH_BNLH_NONE_NONE_NONE_NONE: - case BSNH_BLNH_BLNH_NONE_NONE_BNTH_BNTH: - case BSNH_BNLH_BNLH_NONE_NONE_BNTH_BNTH: - case BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH: - case BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH: - case BSNH_BLN2H_NONE_NONE_NONE_NONE_NONE: - case BLN3H_NONE_NONE_NONE_NONE_NONE_NONE: - return true; - default: - return false; - } - } - - inline static bool IsSupportedMaskType(const RocmAttentionParameters* attn) { - switch (attn->mask_type) { - case MASK_NONE: - case MASK_2D_DUMMY: - case MASK_2D_KEY_PADDING: - case MASK_3D_ATTENTION: - case MASK_4D_MEGATRON: - return true; - default: - return false; - } - } - - inline static size_t GetWorkspaceNumBytes(const RocmAttentionParameters* attn) { - size_t num_bytes = GemmSoftmaxGemmPermuteGenericPipeline::GetWorkspaceNumBytes(attn); - -#ifdef USE_COMPOSABLE_KERNEL - if (IsSupportedMaskType(attn)) { - auto [buffer, sizes, strides] = GetRawMaskBufferAddrSizesAndStrides(nullptr, attn); - num_bytes = std::max(num_bytes, sizeof(T) * sizes.x * sizes.y * sizes.z); - } -#endif - - return num_bytes; - } - - template - __global__ static void ConvertToFilledMaskValue( - T* __restrict__ out, - const int3 out_strides, - const int* __restrict__ mask_buffer, - const int3 mask_lengths, // [B,S,T] - const int3 mask_strides, - Converter cvt) { - const int64_t global_idx = blockDim.x * blockIdx.x + threadIdx.x; - if (global_idx >= mask_lengths.x * mask_lengths.y * CeilDiv(mask_lengths.z, VecSize)) { - return; - } - - const int tidx = (global_idx % CeilDiv(mask_lengths.z, VecSize)) * VecSize; - const int bs_idx = global_idx / CeilDiv(mask_lengths.z, VecSize); - const int sidx = bs_idx % mask_lengths.y; - const int bidx = bs_idx / mask_lengths.y; - - int64_t in_offset = mask_strides.x * bidx + mask_strides.y * sidx + mask_strides.z * tidx; - int64_t out_offset = out_strides.x * bidx + out_strides.y * sidx + out_strides.z * tidx; - - if (tidx + VecSize <= mask_lengths.z) { - using LoadT = const aligned_vector; - using StoreT = aligned_vector; - LoadT load = *reinterpret_cast(mask_buffer + in_offset); - StoreT store; - -#pragma unroll - for (int i = 0; i < VecSize; i++) { - store.val[i] = cvt(load.val[i]); - } - *reinterpret_cast(out + out_offset) = store; - } else { -#pragma unroll - for (int i = 0; i < mask_lengths.z - tidx; i++) { - out[out_offset + i] = cvt(mask_buffer[in_offset + i]); - } - } - } - - static Status LaunchConvertToFilledMaskValue(const GemmSoftmaxGemmPermuteParams* params) { - constexpr const int kThreadPerBlock = 256; - constexpr const int kVecSize = 4; - - auto attn = params->attention; - auto [buffer, lengths, strides] = GetRawMaskBufferAddrSizesAndStrides(params->mask_index_buffer, attn); - int64_t total_threads = lengths.x * lengths.y * CeilDiv(lengths.z, kVecSize); - auto num_blocks = CeilDiv(total_threads, kThreadPerBlock); - - auto mask_filter_value = attn->mask_filter_value; - auto cvt = [=] __device__(int v) -> T { - return v == 1 ? 0 : mask_filter_value; - }; - - ConvertToFilledMaskValue<<StreamHandle()>>>( - reinterpret_cast(params->workspace_buffer), {lengths.y * lengths.z, lengths.z, 1}, // out desc - buffer, lengths, strides, // mask desc - cvt); - - return HIP_CALL(hipGetLastError()); - } -}; - -#ifdef USE_COMPOSABLE_KERNEL - -template -auto GetArgAndRunInvoker(const U& impl, const V& invoker, const GemmSoftmaxGemmPermuteParams* params) { - constexpr const int kNumBiasBuffer = static_cast(USE_BIAS) + static_cast(USE_MASK); - - using Nop = ck::tensor_operation::element_wise::PassThrough; - using Acc0ElementOp = internal::PreSoftmaxAttentionScoreOp; - - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - !GemmSoftmaxGemmPermuteTunableOp::IsSupportedMode(params->attention), - "attention mode is not supported, got ", params->attention->mode); - if constexpr (USE_BIAS) { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->bias_buffer == nullptr, "biased version only support input with bias"); - } else { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->bias_buffer != nullptr, "non-biased version only support input without bias"); - } - if constexpr (USE_MASK) { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - !GemmSoftmaxGemmPermuteTunableOp::IsSupportedMaskType(params->attention), - "mask type is not supported, got ", params->attention->mask_type); - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->mask_index_buffer == nullptr, "masked version only support input with mask"); - } else { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->mask_index_buffer != nullptr, "non-masked version only support input without mask"); - } - - auto attn = params->attention; - const int& G0 = attn->batch_size; - const int& G1 = attn->num_heads; - const int& M = attn->sequence_length; - const int& N = attn->total_sequence_length; - const int& K = attn->head_size; - const int& O = attn->v_head_size; - { - auto [m, n, k, o, batch] = params->GetGemmsMNKOBatch(); - ORT_ENFORCE(M == m && N == n && K == k && O == o && G0 * G1 == batch, "semantic mismatch"); - } - - auto [qs, ks, vs] = GetQkvStrides(attn); - std::vector q_buffer_lengths = {G0, G1, M, K}; - std::vector q_buffer_strides = qs.template ForBNSHCoord>(); - std::vector k_buffer_lengths = {G0, G1, N, K}; - std::vector k_buffer_strides = ks.template ForBNSHCoord>(); - std::vector v_buffer_lengths = {G0, G1, O, N}; - std::vector v_buffer_strides = vs.template ForBNHSCoord>(); - std::vector out_buffer_lengths = {G0, G1, M, O}; - std::vector out_buffer_strides = {M * G1 * O, O, G1 * O, 1}; // permute 0213 - - std::array bias_buffers{}; - std::array, kNumBiasBuffer> bias_lengths{}; - std::array, kNumBiasBuffer> bias_strides{}; - if constexpr (USE_BIAS) { - bias_buffers[0] = const_cast(params->bias_buffer); - bias_lengths[0] = {G0, G1, M, N}; // BN(G0*G1), S(M), T(N) - bias_strides[0] = {G1 * M * N, M * N, N, 1}; - } - if constexpr (USE_MASK) { - bias_buffers[kNumBiasBuffer - 1] = params->workspace_buffer; - bias_lengths[kNumBiasBuffer - 1] = {G0, G1, M, N}; // BN(G0*G1), S(M), T(N) - if (params->mask_index_dims.size() == 2) { // [B,T] - bias_strides[kNumBiasBuffer - 1] = {N, 0, 0, 1}; - } else if (params->mask_index_dims.size() == 3) { // [B,S,T] - bias_strides[kNumBiasBuffer - 1] = {M * N, 0, N, 1}; - } else if (params->mask_index_dims.size() == 4) { // [B,1,max_seq_len,max_seq_len] -->convert--> [B,S,T] - bias_strides[kNumBiasBuffer - 1] = {M * N, 0, N, 1}; - } else { - ORT_ENFORCE(false, "Unreachable"); - } - } - - auto arg = impl->MakeArgumentPointer( - params->q_buffer, params->k_buffer, params->v_buffer, params->out_buffer, - bias_buffers, // Gemm1 bias, as attention mask - {}, // Gemm2 bias - q_buffer_lengths, q_buffer_strides, - k_buffer_lengths, k_buffer_strides, - v_buffer_lengths, v_buffer_strides, - out_buffer_lengths, out_buffer_strides, - bias_lengths, bias_strides, - {}, - {}, - Nop{}, - Nop{}, - Acc0ElementOp{params->scale}, - Nop{}, - Nop{}); - - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), - impl->GetTypeString(), " does not support the params"); - - if constexpr (USE_MASK) { - ORT_RETURN_IF_ERROR(GemmSoftmaxGemmPermuteTunableOp::LaunchConvertToFilledMaskValue(params)); - } - - invoker->Run(arg.get(), StreamConfig{params->StreamHandle()}); - return Status::OK(); -} - -template -auto GetCKGemmSoftmaxGemmPermuteTypeStringAndOps() { - using CKDataType = typename CKDataTypeAdaptor::type; - using D0DataType = typename ck::detail::tuple_concat< - std::conditional_t, ck::Tuple<>>, - std::conditional_t, ck::Tuple<>>>::type; - - constexpr static auto MaskingSpecMaskDisabled = - ck::tensor_operation::device::MaskingSpecialization::MaskDisabled; - constexpr static auto MaskingSpecMaskOutUpperTriangle = - ck::tensor_operation::device::MaskingSpecialization::MaskOutUpperTriangle; - - std::vector>>> - ret; - - for (auto&& impl : internal::GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< - CKDataType, D0DataType, internal::F32, internal::PreSoftmaxAttentionScoreOp, MaskingSpecMaskDisabled>()) { - auto type_string = impl->GetTypeString(); - - auto invoker = impl->MakeInvokerPointer(); - auto op = [impl = std::move(impl), invoker = std::move(invoker)]( - const GemmSoftmaxGemmPermuteParams* params) -> Status { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->attention->is_unidirectional, "unidirectional attention is not supported with MaskingSpecMaskDisabled"); - - return GetArgAndRunInvoker(impl, invoker, params); - }; - ret.emplace_back(std::make_pair(std::move(type_string), std::move(op))); - } - - for (auto&& impl : internal::GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< - CKDataType, D0DataType, internal::F32, internal::PreSoftmaxAttentionScoreOp, MaskingSpecMaskOutUpperTriangle>()) { - auto type_string = impl->GetTypeString(); - - auto invoker = impl->MakeInvokerPointer(); - auto op = [impl = std::move(impl), invoker = std::move(invoker)]( - const GemmSoftmaxGemmPermuteParams* params) -> Status { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - !params->attention->is_unidirectional, "bidirectional attention is not supported with MaskingSpecMaskOutUpperTriangle"); - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->attention->sequence_length != params->attention->total_sequence_length, - "seqence_length != total_seqence_length is not supported with MaskingSpecMaskOutUpperTriangle"); - - return GetArgAndRunInvoker(impl, invoker, params); - }; - ret.emplace_back(std::make_pair(std::move(type_string), std::move(op))); - } - - return ret; -} -#endif // USE_COMPOSABLE_KERNEL - -template -GemmSoftmaxGemmPermuteTunableOp::GemmSoftmaxGemmPermuteTunableOp() { - this->RegisterOp([](const GemmSoftmaxGemmPermuteParams* params) { - return GemmSoftmaxGemmPermuteGenericPipeline::Run(params, false); - }); - -#ifdef USE_COMPOSABLE_KERNEL - for (auto&& [_, op] : GetCKGemmSoftmaxGemmPermuteTypeStringAndOps()) { - this->RegisterOp(std::move(op)); - } - - for (auto&& [_, op] : GetCKGemmSoftmaxGemmPermuteTypeStringAndOps()) { - this->RegisterOp(std::move(op)); - } - - for (auto&& [_, op] : GetCKGemmSoftmaxGemmPermuteTypeStringAndOps()) { - this->RegisterOp(std::move(op)); - } - - for (auto&& [_, op] : GetCKGemmSoftmaxGemmPermuteTypeStringAndOps()) { - this->RegisterOp(std::move(op)); - } -#endif -} - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/decoder_attention_impl.h b/onnxruntime/contrib_ops/rocm/bert/decoder_attention_impl.h deleted file mode 100644 index 0aff519d20e99..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/decoder_attention_impl.h +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include -#include "contrib_ops/cpu/bert/attention_common.h" -#include "core/providers/rocm/shared_inc/rocm_utils.h" -#include "core/providers/rocm/tunable/rocm_tunable.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -Status LaunchDecoderAttentionKernel( - const hipDeviceProp_t& prop, // Device Properties - RocmTuningContext* tuning_ctx, // context for tuning - Stream* stream, // ORT Stream - hipblasHandle_t& hipblas, // hipblas handle - const size_t element_size, // Element size of input tensor - const int batch_size, // Batch size (B) - const int sequence_length, // Sequence length (S) - const int kv_sequence_length, // Key/Value/Cache sequence length - const int num_heads, // Number of attention heads (N) - const int head_size, // Hidden layer size per head (H) - const bool static_kv, // Whether cross attention or not - const bool use_past, // Whether use cache or not - const bool has_layer_state, // Whether output cache or not - const bool has_key_padding_mask, // Whether use key_padding_mask or not - const float mask_filter_value, // Mask filter value - const void* gemm_query_buffer, // Query buffer - const void* gemm_kv_buffer, // Key and value buffer - const bool* key_padding_mask, // Key padding mask - const void* key_cache, // Input key cache - const void* value_cache, // Input value cache - void* qkv_buffer, // Temporary buffer - void* workspace_buffer, // Temporary buffer - void* output, // Output tensor - void* new_key_cache, // New_key_cache tensor - void* new_value_cache // New_value_cache tensor -); - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/elementwise.h b/onnxruntime/contrib_ops/rocm/bert/elementwise.h deleted file mode 100644 index 768295767835a..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/elementwise.h +++ /dev/null @@ -1,84 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "core/common/common.h" -#include "core/providers/rocm/tunable/rocm_tunable.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -template -Status LaunchElementwiseKernel(RocmTuningContext* tuning_ctx, Stream* stream, - const T* input, int input_length, - const T* bias, int bias_length, - T* output); - -// The following is LaunchElementwiseKernel implementation detail. Their interfaces are exposed for kernel explorer. -namespace internal { - -template -struct ElementwiseParams : OpParams { - ElementwiseParams(RocmTuningContext* tuning_ctx, onnxruntime::Stream* stream, - const T* input, const T* bias, T* output, int input_length, int bias_length) - : OpParams(tuning_ctx, stream), - input(input), - bias(bias), - output(output), - input_length(input_length), - bias_length(bias_length) {} - - std::string Signature() const override { - std::string sig = std::to_string(input_length) + "_" + std::to_string(bias_length); - return sig; - } - - const T* input; - const T* bias; - T* output; - int input_length; - int bias_length; -}; - -template -class ElementwiseOp { - public: - Status operator()(const ElementwiseParams* params); - Status IsSupported(const ElementwiseParams* params); -}; - -template -Status ElementwiseStaticSelection(const ElementwiseParams* params); - -template -class ElementwiseTunableOp : public TunableOp> { - public: - ElementwiseTunableOp(); -}; - -} // namespace internal - -#define ELEMENTWISE_FWD_DECL(FnName, T) \ - namespace functor { \ - struct FnName; \ - } - -ELEMENTWISE_FWD_DECL(FastGeLU, float); -ELEMENTWISE_FWD_DECL(FastGeLU, double); -ELEMENTWISE_FWD_DECL(FastGeLU, half); -ELEMENTWISE_FWD_DECL(FastGeLU, BFloat16); - -ELEMENTWISE_FWD_DECL(GeLU, float); -ELEMENTWISE_FWD_DECL(GeLU, double); -ELEMENTWISE_FWD_DECL(GeLU, half); -ELEMENTWISE_FWD_DECL(GeLU, BFloat16); - -ELEMENTWISE_FWD_DECL(ReLU, float); -ELEMENTWISE_FWD_DECL(ReLU, half); -ELEMENTWISE_FWD_DECL(ReLU, BFloat16); - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl.cuh b/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl.cuh deleted file mode 100644 index 8255e70d27e48..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl.cuh +++ /dev/null @@ -1,256 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "core/providers/rocm/tunable/util.h" -#include "core/providers/rocm/cu_inc/common.cuh" -#include "contrib_ops/rocm/bert/elementwise.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -namespace functor { - -struct FastGeLU { - template - __host__ __device__ __forceinline__ void operator()(T& y, const T& x) const { - constexpr const float b = 0.7978845608028654f; // sqrt(2.0/M_PI) - - // const T cdf = a + a * _Tanh(in * (c * in * in + b)); - const T xb = x * T(b); - const T u = xb * T(0.044715f) * x * x + xb; - const T emu = __expf(-u - u); - const T cdf = T(1.0f) / (T(1.0f) + emu); - y = x * cdf; - } -}; - -struct GeLU { - template - __host__ __device__ __forceinline__ void operator()(T& y, const T& x) const { - y = T(0.5f) * x * (T(1.f) + T(erf(0.70710678118f * float(x)))); - } -}; - -struct ReLU { - template - __host__ __device__ __forceinline__ void operator()(T& y, const T& x) const { - y = x >= T{} ? x : T{}; - } -}; - -} // namespace functor - -using onnxruntime::rocm::CeilDiv; -using onnxruntime::rocm::GPU_WARP_SIZE; - -template -__global__ void ElementwiseKernel( - const T* __restrict__ input, int input_length, - const T* __restrict__ bias, int bias_length, - T* __restrict__ output) { - const int idx = blockIdx.x * TPB + threadIdx.x; - Fn f{}; - - if (idx < input_length) { - const T x = input[idx] + (bias == nullptr ? T{} : bias[idx % bias_length]); - f(output[idx], x); - } -} - -template -__global__ void ElementwiseKernelVec( - const T* __restrict__ input, int input_length, - const T* __restrict__ bias, int bias_length, - T* output) { - using VecT = onnxruntime::rocm::aligned_vector; - Fn f{}; - - const int idx = (blockIdx.x * TPB + threadIdx.x) * ILP; - if (idx < input_length) { - T input_v[ILP]; - VecT* input_val = reinterpret_cast(&input_v); - *input_val = *reinterpret_cast(&input[idx]); - T output_v[ILP]; - VecT* output_val = reinterpret_cast(&output_v); - T bias_v[ILP]; - if (bias != nullptr) { - VecT* bias_val = reinterpret_cast(&bias_v); - *bias_val = *reinterpret_cast(&bias[idx % bias_length]); - } - -#pragma unroll - for (int i = 0; i < ILP; i++) { - const T x = (bias == nullptr) ? input_v[i] : (T)(input_v[i] + bias_v[i]); - f(output_v[i], x); - } - *(reinterpret_cast(&output[idx])) = *output_val; - } -} - -template -Status LaunchElementwiseKernel( - RocmTuningContext* tuning_ctx, Stream* stream, - const T* input, int input_length, - const T* bias, int bias_length, - T* output) { - internal::ElementwiseParams params(tuning_ctx, stream, input, bias, output, input_length, bias_length); - if (tuning_ctx->IsTunableOpEnabled()) { - static internal::ElementwiseTunableOp op; - return op(¶ms); - } - - return internal::ElementwiseStaticSelection(¶ms); -} - -namespace internal { - -template -Status ElementwiseOp::operator()(const ElementwiseParams* params) { - dim3 blocks(CeilDiv(params->input_length, ThreadsPerBlock * VecSize)); - ElementwiseKernelVec<<StreamHandle()>>>( - params->input, params->input_length, - params->bias, params->bias_length, - params->output); - return HIP_CALL(hipGetLastError()); -} - -template -Status ElementwiseOp::IsSupported(const ElementwiseParams* params) { - // TODO(anyone): Add tail handling for FastGelu - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - !((params->bias_length > 0 && params->bias_length % VecSize == 0 && params->input_length % VecSize == 0) || - (params->bias_length == 0 && params->input_length % VecSize == 0))); - // Avoid redundant configurations - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!(params->input_length > (ThreadsPerBlock - GPU_WARP_SIZE) * VecSize)); - - return Status::OK(); -} - -template -Status ElementwiseStaticSelection(const ElementwiseParams* params) { - constexpr int block_size = 256; - if constexpr (std::is_same_v) { - if (params->bias != nullptr) { - if (0 == (params->bias_length % 8) && (params->input_length >= 3145728)) { // 3145728=8*128*3072 - const int grid_size = (params->input_length / 8 + block_size - 1) / block_size; - ElementwiseKernelVec<<StreamHandle()>>>( - params->input, params->input_length, params->bias, params->bias_length, params->output); - } else if (0 == (params->bias_length % 4)) { - const int grid_size = (params->input_length / 4 + block_size - 1) / block_size; - ElementwiseKernelVec<<StreamHandle()>>>( - params->input, params->input_length, params->bias, params->bias_length, params->output); - } else if (0 == (params->bias_length % 2)) { - const int grid_size = (params->input_length / 2 + block_size - 1) / block_size; - ElementwiseKernelVec<<StreamHandle()>>>( - params->input, params->input_length, params->bias, params->bias_length, params->output); - } else { - const int grid_size = (params->input_length + block_size - 1) / block_size; - ElementwiseKernel<<StreamHandle()>>>( - params->input, params->input_length, params->bias, params->bias_length, params->output); - } - } else { - if (0 == (params->input_length % 8) && (params->input_length >= 3145728)) { // 3145728=8*128*3072 - const int grid_size = (params->input_length / 8 + block_size - 1) / block_size; - ElementwiseKernelVec<<StreamHandle()>>>( - params->input, params->input_length, params->bias, params->bias_length, params->output); - } else if (0 == (params->input_length % 4)) { - const int grid_size = (params->input_length / 4 + block_size - 1) / block_size; - ElementwiseKernelVec<<StreamHandle()>>>( - params->input, params->input_length, params->bias, params->bias_length, params->output); - } else if (0 == (params->input_length % 2)) { - const int grid_size = (params->input_length / 2 + block_size - 1) / block_size; - ElementwiseKernelVec<<StreamHandle()>>>( - params->input, params->input_length, params->bias, params->bias_length, params->output); - } else { - const int grid_size = (params->input_length + block_size - 1) / block_size; - ElementwiseKernel<<StreamHandle()>>>( - params->input, params->input_length, params->bias, params->bias_length, params->output); - } - } - } else { - const int grid_size = (params->input_length + block_size - 1) / block_size; - ElementwiseKernel<<StreamHandle()>>>( - params->input, params->input_length, params->bias, params->bias_length, params->output); - } - return HIP_CALL(hipGetLastError()); -} - -template -ElementwiseTunableOp::ElementwiseTunableOp() { - this->RegisterOp(ElementwiseStaticSelection); - - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); -} - -#undef ADD_OP - -} // namespace internal - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime - -#define ELEMENTWISE_KERNEL_IMPL(Fn, T) \ - namespace onnxruntime { \ - namespace contrib { \ - namespace rocm { \ - template Status LaunchElementwiseKernel( \ - RocmTuningContext * tuning_ctx, Stream* stream, \ - const T* input, int input_length, \ - const T* bias, int bias_length, \ - T* output); \ - namespace internal { \ - template class ElementwiseTunableOp; \ - } \ - } \ - } \ - } diff --git a/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_fastgelu.cu b/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_fastgelu.cu deleted file mode 100644 index c2a670ea76aca..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_fastgelu.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "contrib_ops/rocm/bert/elementwise_impl/impl.cuh" - -ELEMENTWISE_KERNEL_IMPL(functor::FastGeLU, float); -ELEMENTWISE_KERNEL_IMPL(functor::FastGeLU, double); -ELEMENTWISE_KERNEL_IMPL(functor::FastGeLU, half); -ELEMENTWISE_KERNEL_IMPL(functor::FastGeLU, BFloat16); diff --git a/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_gelu.cu b/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_gelu.cu deleted file mode 100644 index 97f0f74640c6e..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_gelu.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "contrib_ops/rocm/bert/elementwise_impl/impl.cuh" - -ELEMENTWISE_KERNEL_IMPL(functor::GeLU, double); -ELEMENTWISE_KERNEL_IMPL(functor::GeLU, float); -ELEMENTWISE_KERNEL_IMPL(functor::GeLU, half); -ELEMENTWISE_KERNEL_IMPL(functor::GeLU, BFloat16); diff --git a/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_relu.cu b/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_relu.cu deleted file mode 100644 index 67e50869133f5..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_relu.cu +++ /dev/null @@ -1,8 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "contrib_ops/rocm/bert/elementwise_impl/impl.cuh" - -ELEMENTWISE_KERNEL_IMPL(functor::ReLU, float); -ELEMENTWISE_KERNEL_IMPL(functor::ReLU, half); -ELEMENTWISE_KERNEL_IMPL(functor::ReLU, BFloat16); diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu.cc b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu.cc deleted file mode 100644 index fdb62d3a2aec5..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu.cc +++ /dev/null @@ -1,75 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "contrib_ops/rocm/bert/gemm_fast_gelu.h" - -#include "contrib_ops/rocm/bert/gemm_fast_gelu_common.h" -#include "contrib_ops/rocm/bert/gemm_fast_gelu_impl.h" -#include "core/providers/cpu/math/matmul_helper.h" -#include "core/providers/rocm/rocm_common.h" - -using onnxruntime::rocm::ToHipType; - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -#define REGISTER_KERNEL_TYPED(T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - GemmFastGelu, \ - kMSDomain, \ - 1, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - GemmFastGelu); - -REGISTER_KERNEL_TYPED(float) -REGISTER_KERNEL_TYPED(MLFloat16) -REGISTER_KERNEL_TYPED(BFloat16) - -template -Status GemmFastGelu::ComputeInternal(OpKernelContext* ctx) const { - typedef typename ToHipType::MappedType HipT; - - const auto* X = ctx->Input(0); - const auto* W = ctx->Input(1); - const auto* bias = ctx->Input(2); - - bool transa = false; - bool transb = false; - bool trans_batch_a = false; - bool trans_batch_b = false; - - MatMulComputeHelper helper; - ORT_RETURN_IF_ERROR(helper.Compute(X->Shape(), W->Shape(), transa, transb, trans_batch_a, trans_batch_b, false)); - - Tensor* Y = ctx->Output(0, helper.OutputShape()); - - // Bail out early if the output is going to be empty - if (Y->Shape().Size() == 0) - return Status::OK(); - - // gemmfastgelu only support alpha == 1 and beta == 0 - const HipT alpha = ToHipType::FromFloat(1.0f); - const HipT beta = ToHipType::FromFloat(0.0f); - - using onnxruntime::rocm::tunable::blas::BlasOp; - - return blas::row_major::GemmFastGelu( - GetTuningContext(), ctx->GetComputeStream(), GetHipblasHandle(ctx), - transa ? BlasOp::Trans : BlasOp::NonTrans, - transb ? BlasOp::Trans : BlasOp::NonTrans, - helper.M(), helper.N(), helper.K(), - alpha, - reinterpret_cast(X->Data()), helper.Lda(transa), - reinterpret_cast(W->Data()), helper.Ldb(transb), - (nullptr != bias) ? reinterpret_cast(bias->Data()) : nullptr, - beta, - reinterpret_cast(Y->MutableData()), helper.Ldc()); -} - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu.h b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu.h deleted file mode 100644 index ae4f84fa5f033..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu.h +++ /dev/null @@ -1,23 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include "core/common/common.h" -#include "core/providers/rocm/rocm_kernel.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -using onnxruntime::rocm::RocmKernel; - -template -class GemmFastGelu final : public RocmKernel { - public: - GemmFastGelu(const OpKernelInfo& op_kernel_info) : RocmKernel(op_kernel_info) {} - Status ComputeInternal(OpKernelContext* ctx) const override; -}; - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_ck.cuh b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_ck.cuh deleted file mode 100644 index 77f53f9eed027..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_ck.cuh +++ /dev/null @@ -1,133 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include -#include - -#ifdef USE_COMPOSABLE_KERNEL -#include "core/providers/rocm/composable_kernel_common.h" - -#include "ck/ck.hpp" -#include "ck/library/tensor_operation_instance/gpu/gemm_add_fastgelu.hpp" -#include "ck/library/tensor_operation_instance/gpu/gemm_fastgelu.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#endif - -#include "contrib_ops/rocm/bert/gemm_fast_gelu_common.h" - -using onnxruntime::rocm::ToHipType; - -namespace onnxruntime { -namespace contrib { -namespace rocm { -namespace blas { -namespace internal { - -#ifdef USE_COMPOSABLE_KERNEL - -using onnxruntime::rocm::CKBlasOpAdaptor; -using onnxruntime::rocm::CKDataTypeAdaptor; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -using Nop = ck::tensor_operation::element_wise::PassThrough; -using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu; -using FastGelu = ck::tensor_operation::element_wise::FastGelu; - -template -auto GetCKGemmAddFastGeluTypeStringAndOps() { - using CKDataType = typename CKDataTypeAdaptor::type; - using ALayout = typename CKBlasOpAdaptor::type; - using BLayout = typename CKBlasOpAdaptor::type; - using DeviceGemmAddFastGelu = ck::tensor_operation::device::DeviceGemmMultipleD< - ALayout, BLayout, ck::Tuple, Row, - CKDataType, CKDataType, ck::Tuple, CKDataType, - Nop, Nop, AddFastGelu>; - using InstanceFactory = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory; - - std::vector>>> ret; - for (auto&& impl : InstanceFactory::GetInstances()) { - auto type_string = onnxruntime::MakeString("withbias ", impl->GetTypeString()); - - auto invoker = impl->MakeInvokerPointer(); - auto ck_gemmfastgelu_op = [impl = std::move(impl), invoker = std::move(invoker)](const GemmFastGeluParams* params) -> Status { - auto one = ToHipType::FromFloat(1.0f); - auto zero = ToHipType::FromFloat(0.0f); - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->alpha != one || params->beta != zero || params->bias == nullptr, - impl->GetTypeString(), " only supports alpha == 1 and beta == 0 and bias != nullptr"); - - auto nop = Nop{}; - auto addfastgelu = AddFastGelu{}; - auto arg = impl->MakeArgumentPointer(params->a, params->b, std::array{params->bias}, params->c, - params->m, params->n, params->k, - params->lda, params->ldb, std::array{0}, params->ldc, - nop, nop, addfastgelu); - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), - impl->GetTypeString(), " does not support the params"); - invoker->Run(arg.get(), StreamConfig{params->StreamHandle()}); - return Status::OK(); - }; - ret.emplace_back(std::make_pair(std::move(type_string), std::move(ck_gemmfastgelu_op))); - } - return ret; -} - -template -auto GetCKGemmFastGeluTypeStringAndOps() { - using CKDataType = typename CKDataTypeAdaptor::type; - using ALayout = typename CKBlasOpAdaptor::type; - using BLayout = typename CKBlasOpAdaptor::type; - using DeviceGemmFastGelu = ck::tensor_operation::device::DeviceGemmMultipleD< - ALayout, BLayout, ck::Tuple<>, Row, - CKDataType, CKDataType, ck::Tuple<>, CKDataType, - Nop, Nop, FastGelu>; - using InstanceFactory = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory; - - std::vector>>> ret; - for (auto&& impl : InstanceFactory::GetInstances()) { - auto type_string = onnxruntime::MakeString("nobias ", impl->GetTypeString()); - auto invoker = impl->MakeInvokerPointer(); - auto ck_gemmfastgelu_op = [impl = std::move(impl), invoker = std::move(invoker)](const GemmFastGeluParams* params) -> Status { - auto one = ToHipType::FromFloat(1.0f); - auto zero = ToHipType::FromFloat(0.0f); - - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->alpha != one || params->beta != zero || params->bias != nullptr, - impl->GetTypeString(), " only supports alpha == 1 and beta == 0 and bias == nullptr"); - - auto nop = Nop{}; - auto fastgelu = FastGelu{}; - auto arg = impl->MakeArgumentPointer(params->a, params->b, - {}, - params->c, - params->m, params->n, params->k, - params->lda, params->ldb, - {}, - params->ldc, - nop, nop, fastgelu); - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), - impl->GetTypeString(), " does not support the params"); - invoker->Run(arg.get(), StreamConfig{params->StreamHandle()}); - return Status::OK(); - }; - ret.emplace_back(std::make_pair(std::move(type_string), std::move(ck_gemmfastgelu_op))); - } - return ret; -} -#else -struct Row {}; -struct Col {}; -#endif // USE_COMPOSABLE_KERNEL - -} // namespace internal -} // namespace blas -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_common.h b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_common.h deleted file mode 100644 index 2b8a21b83f177..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_common.h +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include - -#include "core/providers/rocm/rocm_common.h" -#include "core/providers/rocm/tunable/gemm_common.h" -#include "core/providers/rocm/tunable/rocm_tunable.h" - -using onnxruntime::rocm::tunable::blas::BlasOp; -using onnxruntime::rocm::tunable::blas::BlasOpToString; - -namespace onnxruntime { -namespace contrib { -namespace rocm { -namespace blas { - -template -struct GemmFastGeluParams : OpParams { - std::string Signature() const override { - bool has_bias = (nullptr != bias) ? 0 : 1; - return MakeString(BlasOpToString(opa), BlasOpToString(opb), "_", m, "_", n, "_", k, '_', has_bias); - } - hipblasHandle_t handle; - BlasOp opa; - BlasOp opb; - int64_t m; - int64_t n; - int64_t k; - T alpha; - const T* a; - int64_t lda; - const T* b; - int64_t ldb; - const T* bias; - T beta; - T* c; - int64_t ldc; -}; - -} // namespace blas -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.cu b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.cu deleted file mode 100644 index 8d7e64b1015be..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.cu +++ /dev/null @@ -1,91 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#define _GEMM_FASTGELU_H_KEEP_SIGNATURE_DEFINES -#include "contrib_ops/rocm/bert/gemm_fast_gelu_impl.h" - -#include -#include - -#include "contrib_ops/rocm/bert/gemm_fast_gelu_tunable.cuh" -#include "core/providers/rocm/shared_inc/fpgeneric.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { -namespace blas { - -namespace row_major { - -template -inline GEMMFASTGELU(T, ScalarT) { - GemmFastGeluParams params; - params.tuning_ctx = tuning_ctx; - params.stream = stream; - params.handle = handle; - - params.opa = opa; - params.opb = opb; - params.m = m; - params.n = n; - params.k = k; - if constexpr (!std::is_same_v && std::is_same_v) { - params.alpha = ToHipType::FromFloat(std::forward(alpha)); - } else { - params.alpha = alpha; - } - params.a = a; - params.lda = lda; - params.b = b; - params.ldb = ldb; - params.bias = bias; - if constexpr (!std::is_same_v && std::is_same_v) { - params.beta = ToHipType::FromFloat(std::forward(beta)); - } else { - params.beta = beta; - } - params.c = c; - params.ldc = ldc; - - if (tuning_ctx->IsTunableOpEnabled()) { - if (opa == BlasOp::N && opb == BlasOp::N) { - static internal::GemmFastGeluTunableOp gemm_fast_gelu{}; - return gemm_fast_gelu(¶ms); - } else if (opa == BlasOp::T && opb == BlasOp::N) { - static internal::GemmFastGeluTunableOp gemm_fast_gelu{}; - return gemm_fast_gelu(¶ms); - } else if (opa == BlasOp::N && opb == BlasOp::T) { - static internal::GemmFastGeluTunableOp gemm_fast_gelu{}; - return gemm_fast_gelu(¶ms); - } else /*if (opa == BlasOp::T && opb == BlasOp::T)*/ { - static internal::GemmFastGeluTunableOp gemm_fast_gelu{}; - return gemm_fast_gelu(¶ms); - } - } - - return internal::GemmFastGeluUnfused(¶ms); -} - -#define CALL_GEMMFASTGELU(T, ScalarT) \ - GemmFastGelu(tuning_ctx, stream, handle, \ - opa, opb, \ - m, n, k, \ - alpha, a, lda, b, ldb, bias, \ - beta, c, ldc) - -// clang-format off -GEMMFASTGELU(float, float ) { return CALL_GEMMFASTGELU(float, float ); } -GEMMFASTGELU(half, half ) { return CALL_GEMMFASTGELU(half, half ); } -GEMMFASTGELU(BFloat16, BFloat16) { return CALL_GEMMFASTGELU(BFloat16, BFloat16); } -GEMMFASTGELU(half, float ) { return CALL_GEMMFASTGELU(half, float ); } -GEMMFASTGELU(BFloat16, float ) { return CALL_GEMMFASTGELU(BFloat16, float ); } -// clang-format on - -#undef CALL_GEMMFASTGELU - -} // namespace row_major - -} // namespace blas -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.h b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.h deleted file mode 100644 index b707c63ef44be..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.h +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "contrib_ops/rocm/bert/gemm_fast_gelu_common.h" -#include "core/common/status.h" -#include "core/common/float16.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { -namespace blas { - -#define GEMMFASTGELU(T, ScalarT) \ - common::Status GemmFastGelu( \ - RocmTuningContext* tuning_ctx, Stream* stream, hipblasHandle_t handle, \ - BlasOp opa, BlasOp opb, \ - std::int64_t m, std::int64_t n, std::int64_t k, \ - ScalarT alpha, const T* a, std::int64_t lda, const T* b, std::int64_t ldb, \ - const T* bias, ScalarT beta, T* c, std::int64_t ldc) - -namespace row_major { - -GEMMFASTGELU(float, float); -GEMMFASTGELU(half, half); -GEMMFASTGELU(BFloat16, BFloat16); -GEMMFASTGELU(half, float); -GEMMFASTGELU(BFloat16, float); - -} // namespace row_major - -} // namespace blas -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime - -#ifndef _GEMM_FASTGELU_H_KEEP_SIGNATURE_DEFINES -#undef GEMMFASTGELU -#endif diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_tunable.cuh b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_tunable.cuh deleted file mode 100644 index e157aa57f8c43..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_tunable.cuh +++ /dev/null @@ -1,83 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include - -#include "contrib_ops/rocm/bert/elementwise.h" -#include "contrib_ops/rocm/bert/gemm_fast_gelu_ck.cuh" -#include "contrib_ops/rocm/bert/gemm_fast_gelu_common.h" -#include "core/providers/rocm/tunable/gemm.h" -#include "core/providers/rocm/tunable/gemm_hipblaslt.h" -#include "core/providers/rocm/tunable/rocm_tunable.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { -namespace blas { -namespace internal { - -using namespace onnxruntime::rocm::tunable::blas::internal; - -template -Status GemmFastGeluUnfused(const GemmFastGeluParams* params) { - namespace column_major = onnxruntime::rocm::tunable::blas::column_major; - ORT_RETURN_IF_ERROR(column_major::Gemm(params->tuning_ctx, params->stream, params->handle, - params->opb, params->opa, - params->n, params->m, params->k, - params->alpha, params->b, params->ldb, params->a, params->lda, - params->beta, params->c, params->ldc)); - - int64_t fast_gelu_input_length = params->m * params->n; - int64_t bias_length = (params->bias != nullptr) ? params->n : 0; - - // Because of GemmFastGeluUnfused is a combination of GemmOp and FastGeluOp, FastGeluOp in this combination is - // an inplace computation. - // 1. If we call GemmFastGeluUnfused directly with enabled tuning, it may cause the input buffer of FastGelu been - // updated accumulatedly and result in incorrect result finally. This only happens if the tuning's FindFastest is invoked. - // 2. It's safe to call GemmFastGeluUnfused with disabled tuning, FastGelu only run once and produce correct result. - // 3. It's safe to call GemmFastGeluUnfused as part of GemmFastGeluTunableOp with enable tuning, GemmTunableOp and - // FastGeluTunableOp will do tune in first warmup step separately during GemmFastGeluUnfused profiling process. - // After that, the call to GemmFastGeluUnfused not invoke tuning's FindFastest of FastGelu. - // - // Note: If any change cause directly usage of GemmFastGeluUnfused, add PreTuning() and PostTuning() in FastGeluTunableOp - // to protect original input value. - return onnxruntime::contrib::rocm::LaunchElementwiseKernel( - params->tuning_ctx, params->Stream(), - params->c, static_cast(fast_gelu_input_length), - params->bias, static_cast(bias_length), - params->c); -} - -template -class GemmFastGeluTunableOp : public TunableOp> { - public: - GemmFastGeluTunableOp() { - this->RegisterOp(GemmFastGeluUnfused); -#ifdef USE_COMPOSABLE_KERNEL - for (auto&& [_, op] : GetCKGemmAddFastGeluTypeStringAndOps()) { - ORT_UNUSED_PARAMETER(_); - this->RegisterOp(std::move(op)); - } - for (auto&& [_, op] : GetCKGemmFastGeluTypeStringAndOps()) { - ORT_UNUSED_PARAMETER(_); - this->RegisterOp(std::move(op)); - } -#endif - -#ifdef USE_HIPBLASLT - for (auto&& [_, op] : GetHipBlasLtGemmFastGeluTypeStringAndOps()) { - ORT_UNUSED_PARAMETER(_); - this->RegisterOp(std::move(op)); - } -#endif - } -}; - -} // namespace internal -} // namespace blas -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu deleted file mode 100644 index 09a6550549614..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu +++ /dev/null @@ -1,530 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/providers/shared_library/provider_api.h" -#include "core/providers/rocm/rocm_common.h" -#include "core/platform/env_var_utils.h" -#include "contrib_ops/rocm/bert/group_query_attention.h" -#include "contrib_ops/cpu/bert/group_query_attention_helper.h" -#include "contrib_ops/rocm/bert/rotary_embedding_impl.h" -#include "contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh" - -#ifdef USE_COMPOSABLE_KERNEL_CK_TILE -#include "ck_tile/core/numeric/integer.hpp" -#include "fmha_fwd.hpp" -#endif - -using namespace onnxruntime::rocm; -using namespace ::onnxruntime::common; -using namespace ONNX_NAMESPACE; - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -#define REGISTER_KERNEL_TYPED(T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - GroupQueryAttention, \ - kMSDomain, \ - 1, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ - .TypeConstraint("M", DataTypeImpl::GetTensorType()) \ - .MayInplace(3, 1) \ - .MayInplace(4, 2) \ - .InputMemoryType(OrtMemTypeCPUInput, 6), \ - GroupQueryAttention); - -// REGISTER_KERNEL_TYPED(float) -REGISTER_KERNEL_TYPED(MLFloat16) -// REGISTER_KERNEL_TYPED(BFloat16) - -template -std::string GetCkFmhaDataTypeString(); - -template <> -std::string GetCkFmhaDataTypeString() { - return "fp16"; -} - -template <> -std::string GetCkFmhaDataTypeString() { - return "bf16"; -} - -__global__ void seqlens_inc_kernel(const int* seqlens, int* out, int num_elems, int inc) { - int idx = blockDim.x * blockIdx.x + threadIdx.x; - if (idx < num_elems) { - out[idx] = seqlens[idx] + inc; - } -} - -Status LaunchSeqlensInc(hipStream_t stream, const int* seqlens, int* out, int num_elems, int inc) { - constexpr int NumThreads = 128; - int num_blks = CeilDiv(num_elems, NumThreads); - seqlens_inc_kernel<<>>(seqlens, out, num_elems, inc); - return HIP_CALL(hipGetLastError()); -} - -__global__ void seqstart_init_kernel(int* out, int num_elems, int length_per_seq) { - int idx = blockDim.x * blockIdx.x + threadIdx.x; - if (idx < num_elems) { - out[idx] = idx * length_per_seq; - } - if (idx == 0) { - out[num_elems] = num_elems * length_per_seq; - } -} - -Status LaunchSeqStartInit(hipStream_t stream, int* out, int num_elems, int length_per_seq) { - constexpr int NumThreads = 128; - int num_blks = CeilDiv(num_elems, NumThreads); - seqstart_init_kernel<<>>(out, num_elems, length_per_seq); - return HIP_CALL(hipGetLastError()); -} - -// Kernel to convert seqlens_k to position_ids -__global__ void SeqlensToPosIdsPrompt(const int32_t* seqlens_k, int64_t* position_ids, const int seqlen, - const int batch_size) { - int tid = blockDim.x * blockIdx.x + threadIdx.x; - int b = tid / seqlen; - int s = tid % seqlen; - if (b < batch_size) { - if (s < seqlens_k[b] + 1) { - position_ids[tid] = s; - } else { - position_ids[tid] = 1; - } - } -} - -// Kernel to convert seqlens_k to position_ids -__global__ void SeqlensToPosIdsToken(const int32_t* seqlens_k, int64_t* position_ids, const int batch_size) { - int tid = blockDim.x * blockIdx.x + threadIdx.x; - if (tid < batch_size) { - position_ids[tid] = seqlens_k[tid]; - } -} - -// Convert seqlens_k to position_ids -Status LaunchSeqlensToPosIds(contrib::GroupQueryAttentionParameters& parameters, const int32_t* seqlens_k, - int64_t* position_ids, hipStream_t stream, const int max_threads_per_block) { - const int seqlen = parameters.sequence_length; - const int batch_size = parameters.batch_size; - const int threads = max_threads_per_block; - const int blocks = (batch_size * seqlen + threads - 1) / threads; - if (parameters.is_first_prompt) { - SeqlensToPosIdsPrompt<<>>(seqlens_k, position_ids, seqlen, batch_size); - } else { - SeqlensToPosIdsToken<<>>(seqlens_k, position_ids, batch_size); - } - return HIP_CALL(hipGetLastError()); -} - -template -GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) - : RocmKernel(info) { - int64_t num_heads = 0; - int64_t kv_num_heads = 0; - ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0); - ORT_ENFORCE(info.GetAttr("kv_num_heads", &kv_num_heads).IsOK() && kv_num_heads > 0 && num_heads % kv_num_heads == 0); - num_heads_ = static_cast(num_heads); - kv_num_heads_ = static_cast(kv_num_heads); - is_past_bsnh_ = false; - is_unidirectional_ = true; - local_window_size_ = static_cast(info.GetAttrOrDefault("local_window_size", -1)); - do_rotary_ = info.GetAttrOrDefault("do_rotary", 0) == 1; - rotary_interleaved_ = info.GetAttrOrDefault("rotary_interleaved", 0) == 1; - scale_ = info.GetAttrOrDefault("scale", 0.0f); -} - -template <> -std::once_flag GroupQueryAttention::arch_checking_{}; - -template <> -std::once_flag GroupQueryAttention::arch_checking_{}; - -template -Status GroupQueryAttention::ComputeInternal(OpKernelContext* ctx) const { -#if USE_COMPOSABLE_KERNEL_CK_TILE - auto hip_stream = static_cast(ctx->GetComputeStream()->GetHandle()); - const Tensor* query = ctx->Input(0); - const Tensor* key = ctx->Input(1); - const Tensor* value = ctx->Input(2); - const Tensor* past_key = ctx->Input(3); - const Tensor* past_value = ctx->Input(4); - const Tensor* seqlens_k = ctx->Input(5); - const Tensor* total_seqlen = ctx->Input(6); - const Tensor* cos_cache = ctx->Input(7); - const Tensor* sin_cache = ctx->Input(8); - - auto& device_prop = GetDeviceProp(); - std::call_once( - arch_checking_, - [](const hipDeviceProp_t& device_prop) { - if (std::string_view(device_prop.gcnArchName).find("gfx90a") == std::string_view::npos && - std::string_view(device_prop.gcnArchName).find("gfx942") == std::string_view::npos) { - LOGS_DEFAULT(WARNING) - << "GroupQueryAttention currently only supports ck_tile fmha backend which only supports " - << "CDNA2 and CDNA3 archs."; - LOGS_DEFAULT(WARNING) - << "GroupQueryAttention running on an unsuppoted GPU may result in " - << "hipErrorNoBinaryForGpu or hipErrorSharedObjectInitFailedshared error."; - } - }, - device_prop); - - GroupQueryAttentionParameters parameters; - using HipT = typename ToHipType::MappedType; - - const int max_thr_per_blk = device_prop.maxThreadsPerBlock; - - ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckInputs(query, - key, - value, - past_key, - past_value, - cos_cache, - sin_cache, - ¶meters, - num_heads_, - kv_num_heads_, - seqlens_k, - total_seqlen, - is_past_bsnh_, - scale_, - max_thr_per_blk)); - - const int batch_size = parameters.batch_size; - const int sequence_length = parameters.sequence_length; - const int kv_sequence_length = parameters.sequence_length; - const int num_heads = parameters.num_heads; - const int kv_num_heads = parameters.kv_num_heads; - const int head_size = parameters.head_size; - AttentionQkvFormat past_kv_format = parameters.past_kv_format; - - parameters.local_window_size = local_window_size_; - parameters.is_unidirectional = is_unidirectional_; - // parameters.zeros_count = kZerosCount; - // parameters.zero_ptr = zeros_.get(); - // parameters.left_padding = left_padding_; - parameters.do_rotary = do_rotary_; - parameters.rotary_interleaved = rotary_interleaved_; - - ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckNoQKOutput( - context->OutputCount(), - static_cast(Info().GetAttrOrDefault("qk_output", static_cast(QKOutputType::NO_OUTPUT))))); - - if (do_rotary_ && (cos_cache == nullptr || sin_cache == nullptr)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "cos_cache and sin_cache must be passed to GroupQueryAttention when do_rotary = 1"); - } - - TensorShapeVector output_shape(3); - output_shape[0] = static_cast(batch_size); - output_shape[1] = static_cast(sequence_length); - output_shape[2] = static_cast(parameters.hidden_size); - Tensor* output = ctx->Output(0, output_shape); - Strides output_strides = Strides::BSNHMemory(batch_size, sequence_length, num_heads, head_size); - - int4 past_shape; - std::vector present_dims; - Strides present_strides; - Strides past_strides; - if (past_kv_format == AttentionQkvFormat::Q_K_V_BSNH) { - past_shape = { - batch_size, parameters.seqlen_past_kv_cache, kv_num_heads, head_size}; - past_strides = Strides::BSNHMemory( - batch_size, parameters.seqlen_past_kv_cache, kv_num_heads, head_size); - present_dims = { - batch_size, parameters.seqlen_present_kv_cache, kv_num_heads, head_size}; - present_strides = Strides::BSNHMemory( - batch_size, parameters.seqlen_present_kv_cache, kv_num_heads, head_size); - } else { // BNSH - past_shape = { - batch_size, kv_num_heads, parameters.seqlen_past_kv_cache, head_size}; - past_strides = Strides::BNSHMemory( - batch_size, kv_num_heads, parameters.seqlen_past_kv_cache, head_size); - present_dims = { - batch_size, kv_num_heads, parameters.seqlen_present_kv_cache, head_size}; - present_strides = Strides::BNSHMemory( - batch_size, kv_num_heads, parameters.seqlen_present_kv_cache, head_size); - } - TensorShape present_shape(present_dims); - Tensor* present_key = ctx->Output(1, present_shape); - Tensor* present_value = ctx->Output(2, present_shape); - - Strides query_strides; - Strides key_strides; - Strides value_strides; - int4 kv_shape{batch_size, kv_num_heads, kv_sequence_length, head_size}; // BNSH coord - const HipT* query_ptr = reinterpret_cast(query->DataRaw()); - const HipT* key_ptr; - const HipT* value_ptr; - if (!parameters.is_packed_qkv) { - query_strides = Strides::BSNHMemory(batch_size, sequence_length, num_heads, head_size); - key_strides = Strides::BSNHMemory(batch_size, kv_sequence_length, kv_num_heads, head_size); - value_strides = key_strides; - key_ptr = reinterpret_cast(key->DataRaw()); - value_ptr = reinterpret_cast(value->DataRaw()); - } else { - query_strides = Strides::BSNHMemory(batch_size, sequence_length, num_heads + 2 * kv_num_heads, head_size); - key_strides = Strides::BSNHMemory(batch_size, sequence_length, num_heads + 2 * kv_num_heads, head_size); - value_strides = query_strides; - const size_t key_offset = static_cast(num_heads * head_size); - const size_t value_offset = static_cast(kv_num_heads * head_size); - key_ptr = query_ptr + key_offset; - value_ptr = key_ptr + value_offset; - } - - IAllocatorUniquePtr rotary_q_tmp; - IAllocatorUniquePtr rotary_k_tmp; - if (parameters.do_rotary) { - size_t q_size = static_cast(batch_size * sequence_length * num_heads * head_size); - size_t k_size = static_cast(batch_size * sequence_length * kv_num_heads * head_size); - auto rotary_q_strides = Strides::BSNHMemory(batch_size, sequence_length, num_heads, head_size); - auto rotary_k_strides = Strides::BSNHMemory(batch_size, sequence_length, kv_num_heads, head_size); - - rotary_q_tmp = GetScratchBuffer(q_size, ctx->GetComputeStream()); - rotary_k_tmp = GetScratchBuffer(k_size, ctx->GetComputeStream()); - auto rotary_position_ids_tmp = GetScratchBuffer(sequence_length * batch_size, ctx->GetComputeStream()); - ORT_RETURN_IF_ERROR(LaunchSeqlensToPosIds(parameters, - reinterpret_cast(seqlens_k->DataRaw()), - reinterpret_cast(rotary_position_ids_tmp.get()), - hip_stream, max_thr_per_blk)); - // Launch rotary embedding kernel - ORT_RETURN_IF_ERROR(LaunchRotaryEmbeddingKernel(hip_stream, rotary_q_tmp.get(), query_ptr, - reinterpret_cast(rotary_position_ids_tmp.get()), - reinterpret_cast(cos_cache->DataRaw()), - reinterpret_cast(sin_cache->DataRaw()), - parameters.batch_size, parameters.sequence_length, - parameters.num_heads, parameters.head_size, - parameters.rotary_dim, parameters.seqlen_present_kv_cache, - /*position_ids_format*/ 1, parameters.rotary_interleaved, - max_thr_per_blk, - query_strides.ForBNSHCoord(), - rotary_q_strides.ForBNSHCoord())); - ORT_RETURN_IF_ERROR(LaunchRotaryEmbeddingKernel(hip_stream, rotary_k_tmp.get(), key_ptr, - reinterpret_cast(rotary_position_ids_tmp.get()), - reinterpret_cast(cos_cache->DataRaw()), - reinterpret_cast(sin_cache->DataRaw()), - parameters.batch_size, parameters.sequence_length, - parameters.kv_num_heads, parameters.head_size, - parameters.rotary_dim, parameters.seqlen_present_kv_cache, - /*position_ids_format*/ 1, parameters.rotary_interleaved, - max_thr_per_blk, - key_strides.ForBNSHCoord(), - rotary_k_strides.ForBNSHCoord())); - query_ptr = reinterpret_cast(rotary_q_tmp.get()); - key_ptr = reinterpret_cast(rotary_k_tmp.get()); - query_strides = rotary_q_strides; - key_strides = rotary_k_strides; - } - - const int* seqlens_k_ptr = seqlens_k ? reinterpret_cast(seqlens_k->DataRaw()) : nullptr; - IAllocatorUniquePtr seqlens_k_tmp; - - // build present kv cache - auto* present_key_ptr = reinterpret_cast(present_key->MutableDataRaw()); - auto* present_value_ptr = reinterpret_cast(present_value->MutableDataRaw()); - if (parameters.is_first_prompt) { - // copy prompt kv to present kv - ORT_RETURN_IF_ERROR(LaunchStridedCopy(hip_stream, key_ptr, kv_shape, key_strides.ForBNSHCoord(), - present_key_ptr, present_strides.ForBNSHCoord(), max_thr_per_blk)); - ORT_RETURN_IF_ERROR(LaunchStridedCopy(hip_stream, value_ptr, kv_shape, value_strides.ForBNSHCoord(), - present_value_ptr, present_strides.ForBNSHCoord(), max_thr_per_blk)); - } else { - const auto* past_key_ptr = past_key == nullptr ? nullptr : reinterpret_cast(past_key->DataRaw()); - const auto* past_value_ptr = past_key == nullptr ? nullptr : reinterpret_cast(past_value->DataRaw()); - parameters.kv_share_buffer = past_key_ptr == present_key_ptr; // FIXME: - if (!parameters.kv_share_buffer) { - // copy past to present, - // NOTE: we do a low perf full buffer copy due to the seqlens_k indicate the seqlen of different seqs are - // not the same, aka, can not be as simple as strided - ORT_RETURN_IF_ERROR(LaunchStridedCopy(hip_stream, past_key_ptr, past_shape, past_strides.ForBNSHCoord(), - present_key_ptr, present_strides.ForBNSHCoord(), max_thr_per_blk)); - ORT_RETURN_IF_ERROR(LaunchStridedCopy(hip_stream, past_value_ptr, past_shape, past_strides.ForBNSHCoord(), - present_value_ptr, present_strides.ForBNSHCoord(), max_thr_per_blk)); - } else { - // In the case of share buffer - ORT_ENFORCE(past_key_ptr == nullptr || past_key_ptr == present_key_ptr); - ORT_ENFORCE(past_key_ptr == nullptr || past_value_ptr == present_value_ptr); - } - // then append new kv to present - size_t buffer_offset = seqlens_k ? 0 : present_strides.OffsetAt(0, 0, kv_sequence_length, 0); - ORT_RETURN_IF_ERROR(LaunchStridedCopy( - hip_stream, key_ptr, kv_shape, key_strides.ForBNSHCoord(), /*in_seqlens_offset=*/nullptr, - present_key_ptr + buffer_offset, present_strides.ForBNSHCoord(), seqlens_k_ptr, - max_thr_per_blk)); - ORT_RETURN_IF_ERROR(LaunchStridedCopy( - hip_stream, value_ptr, kv_shape, value_strides.ForBNSHCoord(), /*in_seqlens_offset=*/nullptr, - present_value_ptr + buffer_offset, present_strides.ForBNSHCoord(), seqlens_k_ptr, - max_thr_per_blk)); - - // NOTE: ORT: seqlens_k Indicates past sequence lengths for token generation case. - // we should call fmha with total sequence lengths - seqlens_k_tmp = GetScratchBuffer(batch_size * sizeof(int), ctx->GetComputeStream()); - ORT_RETURN_IF_ERROR(LaunchSeqlensInc(hip_stream, seqlens_k_ptr, seqlens_k_tmp.get(), batch_size, sequence_length)); - seqlens_k_ptr = seqlens_k_tmp.get(); - } - static_assert(std::is_same_v); - - const float scale = parameters.scale == 0.0f - ? 1.f / sqrt(static_cast(parameters.head_size)) - : parameters.scale; - bias_enum bias_type = bias_enum::no_bias; - - mask_info mask = [&]() { - if (local_window_size_ != -1) { - mask_info ret; - ret.type = mask_enum::window_generic; - ret.left = local_window_size_; - ret.right = parameters.is_unidirectional ? 0 : -1; - // ret.x = kv_sequence_length - (sequence_length - ret.left); - // ret.y = sequence_length + (ret.right - kv_sequence_length); - return ret; - } - - if (parameters.is_first_prompt && is_unidirectional_) { - return mask_info::decode("t", sequence_length, kv_sequence_length); - } - - return mask_info::decode("0", sequence_length, kv_sequence_length); - }(); - - auto seqstart_q_tmp = GetScratchBuffer((batch_size + 1) * sizeof(int), ctx->GetComputeStream()); - auto seqstart_k_tmp = GetScratchBuffer((batch_size + 1) * sizeof(int), ctx->GetComputeStream()); - ORT_RETURN_IF_ERROR(LaunchSeqStartInit( - hip_stream, seqstart_q_tmp.get(), batch_size, - query_strides.strides_for_bnsh_coord.x / query_strides.strides_for_bnsh_coord.z)); - ORT_RETURN_IF_ERROR(LaunchSeqStartInit( - hip_stream, seqstart_k_tmp.get(), batch_size, - present_strides.strides_for_bnsh_coord.x / present_strides.strides_for_bnsh_coord.z)); - - fmha_fwd_args args{ - query_ptr, - present_key->DataRaw(), - present_value->DataRaw(), - nullptr, // bias, alibi/element - nullptr, // lse, logsumexp buffer - output->MutableDataRaw(), - seqstart_q_tmp.get(), // seqstart_q_ptr, for group mode - seqstart_k_tmp.get(), // seqstart_k_ptr, for group mode - seqlens_k_ptr, // seqlen_k_ptr, for group mode - sequence_length, // seqlen_q, for batch mode - kv_sequence_length, // seqlen_k, for batch mode - parameters.batch_size, // batch - parameters.sequence_length, // max_seqlen_q - parameters.head_size, // hdim_q - parameters.head_size, // hdim_v - parameters.num_heads, - parameters.kv_num_heads, - scale, - 1.0f, // scale_p of squant, useless - 1.0f, // scale_o of squant, useless - static_cast(query_strides.strides_for_bnsh_coord.z), // stride_q, to be regarded as stride of dim S - static_cast(present_strides.strides_for_bnsh_coord.z), // stride_k, to be regarded as stride of dim S - static_cast(present_strides.strides_for_bnsh_coord.z), // stride_v, to be regarded as stride of dim S - batch_size, // stride_bias, if alibi, b*h need set this to h, 1*h need set this to 0 - static_cast(output_strides.strides_for_bnsh_coord.z), // stride_o, to be regarded as stride of dim S - static_cast(query_strides.strides_for_bnsh_coord.y), // nhead_stride_q, to be regarded as stride of dim N - static_cast(present_strides.strides_for_bnsh_coord.y), // nhead_stride_k, to be regarded as stride of dim N - static_cast(present_strides.strides_for_bnsh_coord.y), // nhead_stride_v, to be regarded as stride of dim N - 0, // nhead_stride_bias - batch_size, // nhead_stride_lse - static_cast(output_strides.strides_for_bnsh_coord.y), // batch_stride_o, to be regarded as stride of dim B - static_cast(query_strides.strides_for_bnsh_coord.x), // batch_stride_q, to be regarded as stride of dim B - static_cast(present_strides.strides_for_bnsh_coord.x), // batch_stride_k, to be regarded as stride of dim B - static_cast(present_strides.strides_for_bnsh_coord.x), // batch_stride_v, to be regarded as stride of dim B - 0, // batch_stride_bias - num_heads * batch_size, // batch_stride_lse - static_cast(output_strides.strides_for_bnsh_coord.x), // batch_stride_o, to be regarded as stride of dim B - mask.left, // window_size_left - mask.right, // window_size_right - static_cast(mask.type)}; - -#if 0 - std::cout - << "\n sequence_length:" << sequence_length - << "\n kv_sequence_length:" << kv_sequence_length - << "\n seqlen_past_kv_cache:" << parameters.seqlen_past_kv_cache - << "\n seqlen_present_kv_cache:" << parameters.seqlen_present_kv_cache << std::endl; - - std::cout - << "\n q_ptr:" << args.q_ptr - << "\n k_ptr:" << args.k_ptr - << "\n v_ptr:" << args.v_ptr - << "\n bias_ptr:" << args.bias_ptr - << "\n lse_ptr:" << args.lse_ptr - << "\n o_ptr:" << args.o_ptr - << "\n seqstart_q_ptr:" << args.seqstart_q_ptr - << "\n seqstart_k_ptr:" << args.seqstart_k_ptr - << "\n seqlen_k_ptr:" << args.seqlen_k_ptr - << "\n seqlen_q:" << args.seqlen_q - << "\n seqlen_k:" << args.seqlen_k - << "\n batch:" << args.batch - << "\n max_seqlen_q:" << args.max_seqlen_q - << "\n hdim_q:" << args.hdim_q - << "\n hdim_v:" << args.hdim_v - << "\n nhead_q:" << args.nhead_q - << "\n nhead_k:" << args.nhead_k - << "\n scale_s:" << args.scale_s - << "\n scale_p:" << args.scale_p - << "\n scale_o:" << args.scale_o - << "\n stride_q:" << args.stride_q - << "\n stride_k:" << args.stride_k - << "\n stride_v:" << args.stride_v - << "\n stride_bias:" << args.stride_bias - << "\n stride_o:" << args.stride_o - << "\n nhead_stride_q:" << args.nhead_stride_q - << "\n nhead_stride_k:" << args.nhead_stride_k - << "\n nhead_stride_v:" << args.nhead_stride_v - << "\n nhead_stride_bias:" << args.nhead_stride_bias - << "\n nhead_stride_lse:" << args.nhead_stride_lse - << "\n nhead_stride_o:" << args.nhead_stride_o - << "\n batch_stride_q:" << args.batch_stride_q - << "\n batch_stride_k:" << args.batch_stride_k - << "\n batch_stride_v:" << args.batch_stride_v - << "\n batch_stride_bias:" << args.batch_stride_bias - << "\n batch_stride_lse:" << args.batch_stride_lse - << "\n batch_stride_o:" << args.batch_stride_o - << "\n window_size_left:" << args.window_size_left - << "\n window_size_right:" << args.window_size_right - << "\n mask_type:" << args.mask_type - << std::endl; -#endif - - fmha_fwd_traits traits{ - parameters.head_size, - parameters.head_size, // v head size - GetCkFmhaDataTypeString(), - !parameters.is_first_prompt, // true, // is_group_mode - true, // is_v_rowmajor ? dim is fastest : seq is fastest - mask.type, - bias_type, - false, // has_lse - false, // do_fp8_static_quant, aka, squant - }; - - ck_tile::stream_config stream_config{ - hip_stream, - false // time_kernel - }; - - auto duration = fmha_fwd(traits, args, stream_config); - if (duration < 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "fmha_fwd internal error"); - } - HIP_RETURN_IF_ERROR(hipGetLastError()); - - return Status::OK(); -#else - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "GroupQueryAttention requires ck_tile to be enabled"); -#endif -} - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.h b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.h deleted file mode 100644 index ce0de1f761aa5..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.h +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include -#include "core/providers/rocm/rocm_kernel.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -using namespace onnxruntime::rocm; - -template -class GroupQueryAttention final : public RocmKernel { - public: - GroupQueryAttention(const OpKernelInfo& info); - Status ComputeInternal(OpKernelContext* context) const override; - - protected: - int num_heads_; // number of attention heads - int kv_num_heads_; // different for k and v for group query attention - int local_window_size_; - bool is_unidirectional_; - bool is_past_bsnh_; - bool do_rotary_; - bool rotary_interleaved_; - float scale_; - - private: - static std::once_flag arch_checking_; -}; - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/layer_norm.cuh b/onnxruntime/contrib_ops/rocm/bert/layer_norm.cuh deleted file mode 100644 index 2eeb7c3e8f279..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/layer_norm.cuh +++ /dev/null @@ -1,270 +0,0 @@ -#include "hip/hip_runtime.h" -/* - The implementation of this file is based on bert plugins in TensorRT demo: - https://github.com/NVIDIA/TensorRT/tree/release/5.1/demo/BERT/ - -Copyright 2019 NVIDIA Corporation - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. -#pragma once - -#include -#include -#include -#include "core/providers/rocm/rocm_common.h" -#include "core/providers/rocm/cu_inc/common.cuh" -#include "core/providers/rocm/shared_inc/rocm_call.h" - -using namespace onnxruntime::rocm; - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -template -__device__ inline T Rsqrt(const T& x); - -template <> -__device__ inline float Rsqrt(const float& x) { - return rsqrtf(x); -} - -template <> -__device__ inline half Rsqrt(const half& x) { -#if __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__) - return hrsqrt(x); -#else - return half(rsqrtf(static_cast(x))); -#endif -} - -__device__ inline half2 AddHalf2(const half2 a, const half2 b) { -#if __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__) - return __hadd2(a, b); -#else - return __halves2half2(__hadd(a.x, b.x), __hadd(a.y, b.y)); -#endif -} - -struct KeyValuePairSum { - __device__ inline hipcub::KeyValuePair operator()(const hipcub::KeyValuePair& a, - const hipcub::KeyValuePair& b) { - return hipcub::KeyValuePair(a.key + b.key, a.value + b.value); - } - - __device__ inline hipcub::KeyValuePair operator()(const hipcub::KeyValuePair& a, - const hipcub::KeyValuePair& b) { - const half2 a2 = __halves2half2(a.key, a.value); - const half2 b2 = __halves2half2(b.key, b.value); - const half2 res = AddHalf2(a2, b2); - return hipcub::KeyValuePair(__low2half(res), __high2half(res)); - } - - __device__ inline hipcub::KeyValuePair operator()(const hipcub::KeyValuePair& a, - const hipcub::KeyValuePair& b) { - return hipcub::KeyValuePair(AddHalf2(a.key, b.key), AddHalf2(a.value, b.value)); - } -}; - -template -__device__ inline void LayerNorm( - const hipcub::KeyValuePair& thread_data, const int ld, const int offset, const V* beta, - const V* gamma, const U epsilon, V* output) { - // Assuming thread_data is already divided by ld - - using BlockReduce = hipcub::BlockReduce, TPB>; - __shared__ typename BlockReduce::TempStorage temp_storage; - __shared__ U mu; // mean - __shared__ U rsigma; // 1 / std.dev. - - KeyValuePairSum pair_sum; - const auto sum_kv = BlockReduce(temp_storage).Reduce(thread_data, pair_sum); - - if (threadIdx.x == 0) { - mu = sum_kv.key; - rsigma = Rsqrt(sum_kv.value - mu * mu + epsilon); - } - __syncthreads(); - - for (int i = threadIdx.x; i < ld; i += TPB) { - const int idx = offset + i; - const U val = static_cast(output[idx]); - const U g = static_cast(gamma[i]); - const U b = (nullptr == beta) ? U(0.f) : static_cast(beta[i]); - output[idx] = static_cast(g * (val - mu) * rsigma + b); - } -} - -template -__device__ inline void SimplifiedLayerNorm( - const U& thread_data, const int ld, const int offset, const V* gamma, const U epsilon, V* output) { - // Assuming thread_data is already divided by ld - - using BlockReduce = hipcub::BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; - __shared__ U rsigma; // 1 / std.dev. - - const U sum = BlockReduce(temp_storage).Sum(thread_data); - - if (threadIdx.x == 0) { - rsigma = Rsqrt(sum + epsilon); - } - __syncthreads(); - - for (int i = threadIdx.x; i < ld; i += TPB) { - const int idx = offset + i; - const U val = static_cast(output[idx]); - const U g = static_cast(gamma[i]); - output[idx] = static_cast(g * val * rsigma); - } -} - -template -__device__ inline void SimplifiedLayerNormVec( - const U& thread_data, const int ld, const int offset, const V* gamma, const U epsilon, V* output) { - // Assuming thread_data is already divided by ld - using VecV = aligned_vector; - using BlockReduce = hipcub::BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; - __shared__ U rsigma; // 1 / std.dev. - - const U sum = BlockReduce(temp_storage).Sum(thread_data); - - if (threadIdx.x == 0) { - rsigma = Rsqrt(sum + epsilon); - } - __syncthreads(); - - if (ILP * threadIdx.x < ld) { - for (int i = threadIdx.x * ILP; i < ld; i += TPB * ILP) { - int idx = offset + i; - const VecV gamma_v = *reinterpret_cast(gamma + i); - VecV output_v = *reinterpret_cast(output + idx); - -#pragma unroll - for (int k = 0; k < ILP; k++) { - output_v.val[k] = U(gamma_v.val[k]) * U(output_v.val[k]) * rsigma; - } - *(reinterpret_cast(output + idx)) = output_v; - } - } -} - -template -__device__ inline void LayerNormVec( - const hipcub::KeyValuePair& thread_data, const int ld, const int offset, const V* beta, - const V* gamma, const U epsilon, V* output) { - // Assuming thread_data is already divided by ld - using VecV = aligned_vector; - using BlockReduce = hipcub::BlockReduce, TPB>; - __shared__ typename BlockReduce::TempStorage temp_storage; - __shared__ U mu; // mean - __shared__ U rsigma; // 1 / std.dev. - - KeyValuePairSum pair_sum; - const auto sum_kv = BlockReduce(temp_storage).Reduce(thread_data, pair_sum); - - if (threadIdx.x == 0) { - mu = sum_kv.key; - rsigma = Rsqrt(sum_kv.value - mu * mu + epsilon); - } - __syncthreads(); - - if (ILP * threadIdx.x < ld) { - for (int i = threadIdx.x * ILP; i < ld; i += TPB * ILP) { - int idx = offset + i; - const VecV beta_v = (beta != nullptr) ? *reinterpret_cast(beta + i) : VecV(); - const VecV gamma_v = *reinterpret_cast(gamma + i); - VecV output_v = *reinterpret_cast(output + idx); - -#pragma unroll - for (int k = 0; k < ILP; k++) { - output_v.val[k] = (beta != nullptr) ? U(gamma_v.val[k]) * (U(output_v.val[k]) - mu) * rsigma + U(beta_v.val[k]) : U(gamma_v.val[k]) * (U(output_v.val[k]) - mu) * rsigma; - } - *(reinterpret_cast(output + idx)) = output_v; - } - } -} - -template -__device__ inline void LayerNormSmall(const T* input_v, const hipcub::KeyValuePair& thread_data, - const int ld, const int idx, const V* beta, const V* gamma, - const U epsilon, V* output) { - // Assuming thread_data is already divided by ld - // Small settings: the block covers the leading dimension TPB >= ld. The input - // value is available in a register - using VecV = aligned_vector; - using BlockReduce = hipcub::BlockReduce, TPB>; - __shared__ typename BlockReduce::TempStorage temp_storage; - __shared__ U mu; // mean - __shared__ U rsigma; // 1 / std.dev. - - KeyValuePairSum pair_sum; - const hipcub::KeyValuePair sum_kv = BlockReduce(temp_storage).Reduce(thread_data, pair_sum); - - if (threadIdx.x == 0) { - mu = sum_kv.key; - rsigma = Rsqrt(sum_kv.value - mu * mu + epsilon); - } - __syncthreads(); - - if (ILP * threadIdx.x < ld) { - const VecV beta_v = (beta != nullptr) ? *reinterpret_cast(beta + threadIdx.x * ILP) : VecV(); - const VecV gamma_v = *reinterpret_cast(gamma + threadIdx.x * ILP); - VecV output_v; - -#pragma unroll - for (int i = 0; i < ILP; i++) { - output_v.val[i] = (beta != nullptr) ? U(gamma_v.val[i]) * (U(input_v[i]) - mu) * rsigma + U(beta_v.val[i]) : U(gamma_v.val[i]) * (U(input_v[i]) - mu) * rsigma; - } - *(reinterpret_cast(output + idx)) = output_v; - } -} - -template -__device__ inline void SimplifiedLayerNormSmall(const T* input_v, const U& thread_data, const int ld, const int idx, - const V* gamma, const U epsilon, V* output) { - // Assuming thread_data is already divided by ld - // Small settings: the block covers the leading dimension TPB >= ld. The input - // value is available in a register - using VecV = aligned_vector; - using BlockReduce = hipcub::BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; - __shared__ U rsigma; // 1 / std.dev. - - const U sum = BlockReduce(temp_storage).Sum(thread_data); - - if (threadIdx.x == 0) { - rsigma = Rsqrt(sum + epsilon); - } - __syncthreads(); - - if (ILP * threadIdx.x < ld) { - const VecV gamma_v = *reinterpret_cast(gamma + threadIdx.x * ILP); - VecV output_v; - -#pragma unroll - for (int i = 0; i < ILP; i++) { - output_v.val[i] = U(gamma_v.val[i]) * U(input_v[i]) * rsigma; - } - *(reinterpret_cast(output + idx)) = output_v; - } -} - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu b/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu deleted file mode 100644 index 5d4ef53b8ba97..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu +++ /dev/null @@ -1,286 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "contrib_ops/rocm/bert/multihead_attention.h" - -#include "contrib_ops/cpu/bert/multihead_attention_helper.h" -#include "contrib_ops/rocm/bert/attention_impl.h" -#include "contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh" -#include "core/platform/env_var_utils.h" -#include "core/providers/rocm/rocm_common.h" - -using namespace onnxruntime::rocm; -using namespace ::onnxruntime::common; -using namespace ONNX_NAMESPACE; - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -#define REGISTER_MHA_KERNEL_TYPED(T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - MultiHeadAttention, \ - kMSDomain, \ - 1, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - MultiHeadAttention) - -REGISTER_MHA_KERNEL_TYPED(float); -REGISTER_MHA_KERNEL_TYPED(MLFloat16); - -static constexpr int kPastSequenceLengthInputIndex = 7; -static constexpr int kBeamWidthInputIndex = 8; -static constexpr int kPastInputIndex = 5; -static constexpr int kPresentOutputIndex = 1; - -#define REGISTER_DMMHA_KERNEL_TYPED(T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - DecoderMaskedMultiHeadAttention, \ - kMSDomain, \ - 1, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .MayInplace(kPastInputIndex, kPresentOutputIndex) \ - .MayInplace(kPastInputIndex + 1, kPresentOutputIndex + 1) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ - .InputMemoryType(OrtMemTypeCPUInput, kPastSequenceLengthInputIndex) \ - .InputMemoryType(OrtMemTypeCPUInput, kBeamWidthInputIndex), \ - MultiHeadAttention) - -REGISTER_DMMHA_KERNEL_TYPED(float); -REGISTER_DMMHA_KERNEL_TYPED(MLFloat16); - -template -MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info) - : RocmKernel(info), - attn_type_(info.node().OpType() == "DecoderMaskedMultiHeadAttention" ? kDecoderMaskedMultiHeadAttention - : kMultiHeadAttention) { - int64_t num_heads = 0; - ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0); - num_heads_ = static_cast(num_heads); - - mask_filter_value_ = info.GetAttrOrDefault("mask_filter_value", -10000.0f); - - scale_ = info.GetAttrOrDefault("scale", 0.0f); - - past_present_share_buffer_ = info.GetAttrOrDefault("past_present_share_buffer", 0LL) != 0LL; - is_unidirectional_ = info.GetAttrOrDefault("unidirectional", 0) == 1; - - using HipT = typename ToHipType::MappedType; - using AttentionTunableOp = GemmSoftmaxGemmPermuteTunableOp; - tunable_op_ = std::make_shared(); -} - -template -Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { - ORT_ENFORCE( - GetTuningContext()->IsTunableOpEnabled(), - "MultiHeadAttention of ROCm EP is only supported if tunable op is used and tuning is enabled."); - - const Tensor* query = context->Input(0); - const Tensor* key = context->Input(1); - const Tensor* value = context->Input(2); - - const Tensor* bias{}; - const Tensor* key_padding_mask{}; - const Tensor* attention_bias{}; - const Tensor* past_key{}; - const Tensor* past_value{}; - const Tensor* past_seq_len{}; - - const Tensor* cache_indirection = nullptr; - - if (attn_type_ == kMultiHeadAttention) { - bias = context->Input(3); - key_padding_mask = context->Input(4); - attention_bias = context->Input(5); - past_key = context->Input(6); - past_value = context->Input(7); - } else if (attn_type_ == kDecoderMaskedMultiHeadAttention) { - key_padding_mask = context->Input(3); - attention_bias = context->Input(4); - past_key = context->Input(5); - past_value = context->Input(6); - past_seq_len = context->Input(kPastSequenceLengthInputIndex); - // const Tensor* beam_width = context->Input(8); // NOTE: not used - // const Tensor* cache_indirection = context->Input(9); // TODO: should not present for ROCm EP - bias = context->Input(10); - } - - if (nullptr != bias) { - return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, - "qkv_bias is not supported on ROCm EP. " - "User should fuse the qkv bias to qkv projection instead."); - } - - auto& device_prop = GetDeviceProp(); - RocmAttentionParameters attn; - ORT_RETURN_IF_ERROR(multihead_attention_helper::CheckInputs(query, - key, - value, - bias, - key_padding_mask, - attention_bias, - past_key, - past_value, - cache_indirection, - past_seq_len, - &attn, /* parameters */ - num_heads_, - mask_filter_value_, - scale_, - is_unidirectional_, - past_present_share_buffer_, - attn_type_, - device_prop.maxThreadsPerBlock)); - - if (attn_type_ == kDecoderMaskedMultiHeadAttention && attn.sequence_length != 1) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input sequence length should be 1 to use DecoderMaskedMultiHeadAttention"); - } - - TensorShapeVector output_shape(3); - output_shape[0] = static_cast(attn.batch_size); - output_shape[1] = static_cast(attn.sequence_length); - output_shape[2] = static_cast(attn.v_hidden_size); - Tensor* output = context->Output(0, output_shape); - - std::vector present_dims{ - attn.batch_size, - attn.num_heads, - past_present_share_buffer_ ? attn.max_sequence_length : attn.total_sequence_length, - attn.head_size, - }; - TensorShape present_shape(present_dims); - Tensor* present_key = context->Output(1, present_shape); - Tensor* present_value = context->Output(2, present_shape); - - ORT_RETURN_IF_ERROR(ClassifyAttentionMode( - attn_type_, &attn, - /*qkv=*/{query, key, value}, - /*past=*/{past_key, past_value}, - /*present=*/{present_key, present_value})); - - using HipT = typename ToHipType::MappedType; - using AttentionTunableOp = GemmSoftmaxGemmPermuteTunableOp; - auto workspace_bytes = AttentionTunableOp::GetWorkspaceNumBytes(&attn); - auto workspace = GetScratchBuffer(workspace_bytes, context->GetComputeStream()); - - hipStream_t stream = Stream(context); - if (nullptr != present_key) { // process past present concat - Strides dst_strides; - - int4 past_shape; - Strides past_src_strides; - const HipT* past_key_src; - const HipT* past_value_src; - HipT* past_key_dst{}; - HipT* past_value_dst{}; - - int4 add_shape; - Strides add_src_strides; - const HipT* add_key_src = reinterpret_cast(key->DataRaw()); - const HipT* add_value_src = reinterpret_cast(value->DataRaw()); - HipT* add_key_dst; - HipT* add_value_dst; - - if (attn.mode == BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH || - attn.mode == BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH) { - dst_strides = Strides::BNSHMemory(attn.batch_size, attn.num_heads, attn.total_sequence_length, attn.head_size); - - past_shape = {attn.batch_size, attn.num_heads, attn.past_sequence_length, attn.head_size}; - past_src_strides = Strides::BNSHMemory(attn.batch_size, attn.num_heads, attn.past_sequence_length, attn.head_size); - past_key_src = reinterpret_cast(past_key->DataRaw()); - past_value_src = reinterpret_cast(past_value->DataRaw()); - past_key_dst = reinterpret_cast(present_key->MutableDataRaw()); - past_value_dst = reinterpret_cast(present_value->MutableDataRaw()); - - if (attn.mode == BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH) { - add_src_strides = Strides::BSNHMemory(attn.batch_size, attn.kv_sequence_length, attn.num_heads, attn.head_size); - } else if (attn.mode == BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH) { - add_src_strides = Strides::BNSHMemory(attn.batch_size, attn.num_heads, attn.kv_sequence_length, attn.head_size); - } - } else if (attn.mode == BSNH_BLNH_BLNH_NONE_NONE_BNTH_BNTH || - attn.mode == BSNH_BNLH_BNLH_NONE_NONE_BNTH_BNTH) { - dst_strides = Strides::BNSHMemory(attn.batch_size, attn.num_heads, attn.total_sequence_length, attn.head_size); - - if (attn.mode == BSNH_BLNH_BLNH_NONE_NONE_BNTH_BNTH) { - add_src_strides = Strides::BSNHMemory(attn.batch_size, attn.kv_sequence_length, attn.num_heads, attn.head_size); - } else if (attn.mode == BSNH_BNLH_BNLH_NONE_NONE_BNTH_BNTH) { - add_src_strides = Strides::BNSHMemory(attn.batch_size, attn.num_heads, attn.kv_sequence_length, attn.head_size); - } - } else if ( - attn.mode == BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH || - attn.mode == BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH || - attn.mode == BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH || - attn.mode == BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH) { - dst_strides = Strides::BNSHMemory(attn.batch_size, attn.num_heads, attn.max_sequence_length, attn.head_size); - - if (attn.mode == BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH || attn.mode == BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH) { - add_src_strides = Strides::BSNHMemory(attn.batch_size, attn.kv_sequence_length, attn.num_heads, attn.head_size); - } else if (attn.mode == BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH || attn.mode == BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH) { - add_src_strides = Strides::BNSHMemory(attn.batch_size, attn.num_heads, attn.kv_sequence_length, attn.head_size); - } - } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, - "past present concatenation is not implemented for attention mode ", attn.mode); - } - add_shape = {attn.batch_size, attn.num_heads, attn.kv_sequence_length, attn.head_size}; // kernel in coord (b,n,s,h) - add_key_dst = reinterpret_cast(present_key->MutableDataRaw()) + dst_strides.OffsetAt(0, 0, attn.past_sequence_length, 0); - add_value_dst = reinterpret_cast(present_value->MutableDataRaw()) + dst_strides.OffsetAt(0, 0, attn.past_sequence_length, 0); - - if (past_key_dst) { - ORT_RETURN_IF_ERROR(LaunchStridedCopy( - stream, past_key_src, past_shape, past_src_strides.ForBNSHCoord(), - past_key_dst, dst_strides.ForBNSHCoord(), device_prop.maxThreadsPerBlock)); - } - if (past_value_dst) { - ORT_RETURN_IF_ERROR(LaunchStridedCopy( - stream, past_value_src, past_shape, past_src_strides.ForBNSHCoord(), - past_value_dst, dst_strides.ForBNSHCoord(), device_prop.maxThreadsPerBlock)); - } - - ORT_RETURN_IF_ERROR(LaunchStridedCopy( - stream, add_key_src, add_shape, add_src_strides.ForBNSHCoord(), - add_key_dst, dst_strides.ForBNSHCoord(), device_prop.maxThreadsPerBlock)); - ORT_RETURN_IF_ERROR(LaunchStridedCopy( - stream, add_value_src, add_shape, add_src_strides.ForBNSHCoord(), - add_value_dst, dst_strides.ForBNSHCoord(), device_prop.maxThreadsPerBlock)); - } - - GemmSoftmaxGemmPermuteParams params; - params.tuning_ctx = GetTuningContext(); - params.stream = context->GetComputeStream(); - params.handle = GetHipblasHandle(context); - params.attention = &attn; - params.device_prop = &device_prop; - params.scale = scale_ == 0 ? 1.0f / sqrt(attn.head_size) : scale_; - std::tie(params.q_buffer, params.k_buffer, params.v_buffer) = ConvertToOffsetedBufferViews( - &attn, - nullptr == query ? nullptr : reinterpret_cast(query->DataRaw()), - nullptr == key ? nullptr : reinterpret_cast(key->DataRaw()), - nullptr == value ? nullptr : reinterpret_cast(value->DataRaw()), - nullptr == present_key ? nullptr : reinterpret_cast(present_key->DataRaw()), - nullptr == present_value ? nullptr : reinterpret_cast(present_value->DataRaw())); - params.out_buffer = reinterpret_cast(output->MutableDataRaw()); - - if (key_padding_mask != nullptr) { - params.mask_index_buffer = key_padding_mask->Data(); - params.mask_index_dims = key_padding_mask->Shape().AsShapeVector(); - } - - if (attention_bias != nullptr) { - params.bias_buffer = reinterpret_cast(attention_bias->DataRaw()); - } - - params.workspace_buffer = reinterpret_cast(workspace.get()); - return (*std::static_pointer_cast(tunable_op_))(¶ms); -} - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/multihead_attention.h b/onnxruntime/contrib_ops/rocm/bert/multihead_attention.h deleted file mode 100644 index 1d676d7a7bcac..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/multihead_attention.h +++ /dev/null @@ -1,51 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include "core/providers/rocm/rocm_kernel.h" -#include "contrib_ops/rocm/bert/attention_impl.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -using namespace onnxruntime::rocm; - -template -class MultiHeadAttention final : public RocmKernel { - public: - MultiHeadAttention(const OpKernelInfo& info); - Status ComputeInternal(OpKernelContext* context) const override; - - protected: - AttentionType attn_type_; - int num_heads_; // number of attention heads - float mask_filter_value_; - float scale_; - bool past_present_share_buffer_{false}; - bool is_unidirectional_{false}; - - // type-erased GemmSoftmaxGemmPermuteTunableOp, the reason for this is: - // 1. We don't want to include the cuh file where GemmSoftmaxGemmPermuteTunableOp is defined. - // 2. We don't want to construct the object repeatly (which is expansive) during Compute. - std::shared_ptr tunable_op_; -}; - -template -class DecoderMaskedMultiHeadAttention final : public RocmKernel { - public: - DecoderMaskedMultiHeadAttention(const OpKernelInfo& info); - Status ComputeInternal(OpKernelContext* context) const override; - - protected: - AttentionType mha_type; - int num_heads_; // number of attention heads - float mask_filter_value_; - float scale_; -}; - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.cc b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.cc deleted file mode 100644 index 9e649fb591896..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.cc +++ /dev/null @@ -1,132 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "contrib_ops/rocm/bert/skip_layer_norm.h" - -#include "core/providers/rocm/rocm_common.h" -#include "contrib_ops/rocm/bert/skip_layer_norm_impl.h" -#include "contrib_ops/rocm/bert/transformer_common.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -#define REGISTER_KERNEL_TYPED(T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - SkipLayerNormalization, \ - kMSDomain, \ - 1, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - SkipLayerNorm); \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - SkipSimplifiedLayerNormalization, \ - kMSDomain, \ - 1, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - SkipLayerNorm); - -REGISTER_KERNEL_TYPED(float) -REGISTER_KERNEL_TYPED(MLFloat16) - -using namespace ONNX_NAMESPACE; - -template -SkipLayerNorm::SkipLayerNorm(const OpKernelInfo& op_kernel_info) : RocmKernel(op_kernel_info) { - ORT_ENFORCE(op_kernel_info.GetAttr("epsilon", &epsilon_).IsOK()); - ORT_ENFORCE(epsilon_ >= 0); -} - -template -Status SkipLayerNorm::ComputeInternal(OpKernelContext* ctx) const { - const Tensor* input = ctx->Input(0); - const Tensor* skip = ctx->Input(1); - const Tensor* gamma = ctx->Input(2); - - const Tensor* beta = Simplified ? nullptr : ctx->Input(3); - const Tensor* bias = Simplified ? ctx->Input(3) : ctx->Input(4); - - Tensor* output = ctx->Output(0, input->Shape()); - - // For inferencing, we support one more optional output which is the sum - // of the input and skip tensors - Tensor* skip_input_bias_add_output = ctx->Output(3, input->Shape()); - - if (input->Shape() != skip->Shape()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "skip is expected to have same shape as input"); - } - - if (input->Shape().Size() == 0) { - return Status::OK(); - } - - const auto& input_dims = input->Shape().GetDims(); - size_t input_dims_size = input_dims.size(); - if (input_dims_size != 3 && input_dims_size != 2) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "input is expected to have 3 or 2 dimensions, got ", input_dims_size); - } - - int hidden_size = static_cast(input_dims[input_dims_size - 1]); - - const auto& gamma_dims = gamma->Shape().GetDims(); - if (gamma_dims.size() != 1) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "gamma is expected to have 1 dimension, got ", gamma_dims.size()); - } - if (gamma_dims[0] != hidden_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Last dimension of gamma and input does not match"); - } - - if (nullptr != beta) { - const auto& beta_dims = beta->Shape().GetDims(); - if (beta_dims.size() != 1) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "beta is expected to have 1 dimension, got ", beta_dims.size()); - } - if (beta_dims[0] != hidden_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Last dimension of beta and input does not match"); - } - } - - if (nullptr != bias) { - const auto& bias_dims = bias->Shape().GetDims(); - if (bias_dims.size() != 1) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "bias is expected to have 1 dimension, got ", bias_dims.size()); - } - if (bias_dims[0] != hidden_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Last dimension of bias and input does not match"); - } - } - - int64_t element_count = input->Shape().Size(); - typedef typename ToHipType::MappedType HipT; - - return LaunchSkipLayerNormKernel( - GetTuningContext(), - ctx->GetComputeStream(), - reinterpret_cast(output->MutableData()), - skip_input_bias_add_output != nullptr ? reinterpret_cast(skip_input_bias_add_output->MutableData()) : nullptr, - reinterpret_cast(input->Data()), - reinterpret_cast(skip->Data()), - reinterpret_cast(gamma->Data()), - (beta != nullptr) ? reinterpret_cast(beta->Data()) : nullptr, - (bias != nullptr) ? reinterpret_cast(bias->Data()) : nullptr, - epsilon_, - hidden_size, - static_cast(element_count)); -} - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.h b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.h deleted file mode 100644 index 02228bc59cedc..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.h +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include "core/common/common.h" -#include "core/providers/rocm/rocm_kernel.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -using namespace onnxruntime::rocm; - -template -class SkipLayerNorm final : public RocmKernel { - public: - SkipLayerNorm(const OpKernelInfo& op_kernel_info); - Status ComputeInternal(OpKernelContext* context) const override; - - private: - float epsilon_; -}; - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.cu b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.cu deleted file mode 100644 index 8387c49a3310b..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.cu +++ /dev/null @@ -1,86 +0,0 @@ -#include "hip/hip_runtime.h" -/* - The implementation of this file is based on skipLayerNorm plugin in TensorRT demo: - https://github.com/NVIDIA/TensorRT/tree/release/5.1/demo/BERT/ - -Copyright 2019 NVIDIA Corporation - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -// Modifications: Add SkipLayerNormKernelVec to -// leverage vectorized load/write. -// and templatize ComputeSkipLayerNorm for different -// data types. -// Copyright (c) Advanced Micro Devices, Inc. All rights reserved. -// Licensed under the MIT License. - -#include "contrib_ops/rocm/bert/skip_layer_norm_impl.h" - -#include - -#include "contrib_ops/rocm/bert/skip_layer_norm_impl_kernel.h" -#include "contrib_ops/rocm/bert/skip_layer_norm_tunable_op.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -template -Status LaunchSkipLayerNormKernel( - RocmTuningContext* tuning_ctx, Stream* stream, V* output, T* skip_input_bias_add_output, const T* input, - const T* skip, const V* gamma, const V* beta, const T* bias, float epsilon, int ld, int element_count) { - // this must be true because element_count is the total size of the tensor - assert(element_count % ld == 0); - - SkipLayerNormParams params(tuning_ctx, stream, output, skip_input_bias_add_output, input, skip, - gamma, beta, bias, epsilon, ld, element_count); - - if (tuning_ctx->IsTunableOpEnabled()) { - static SkipLayerNormTunableOp op; - return op(¶ms); - } - - return SkipLayerNormStaticSelection(¶ms); -} - -template Status LaunchSkipLayerNormKernel( - RocmTuningContext* tuning_ctx, Stream* stream, float* output, float* skip_input_bias_add_output, const float* input, - const float* skip, const float* gamma, const float* beta, - const float* bias, float epsilon, int ld, - int element_count); - -template Status LaunchSkipLayerNormKernel( - RocmTuningContext* tuning_ctx, Stream* stream, half* output, half* skip_input_bias_add_output, const half* input, - const half* skip, const half* gamma, const half* beta, - const half* bias, float epsilon, int ld, - int element_count); - -template Status LaunchSkipLayerNormKernel( - RocmTuningContext* tuning_ctx, Stream* stream, float* output, float* skip_input_bias_add_output, const float* input, - const float* skip, const float* gamma, const float* beta, - const float* bias, float epsilon, int ld, - int element_count); - -template Status LaunchSkipLayerNormKernel( - RocmTuningContext* tuning_ctx, Stream* stream, half* output, half* skip_input_bias_add_output, const half* input, - const half* skip, const half* gamma, const half* beta, - const half* bias, float epsilon, int ld, - int element_count); - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.h b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.h deleted file mode 100644 index 5e2a92447d2f5..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.h +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "core/common/common.h" -#include "core/providers/rocm/tunable/rocm_tunable.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -template -Status LaunchSkipLayerNormKernel( - RocmTuningContext* tuning, - Stream* stream, - V* output, // output tensor - T* skip_input_bias_add_output, // optional output tensor - const T* input, // input tensor - const T* skip, // skip tensor - const V* gamma, // Layer normalization gamma tensor - const V* beta, // Layer normalization beta tensor - const T* bias, // Layer normalization beta tensor - float epsilon, // Layer normalization epsilon - int hidden_size, // hidden size, it is the leading dimension (ld) - int element_count // number of elements in input tensor -); - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl_kernel.h b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl_kernel.h deleted file mode 100644 index fcfbc8969e498..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl_kernel.h +++ /dev/null @@ -1,162 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include "contrib_ops/rocm/bert/layer_norm.cuh" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -template -T maybe2half(float x); - -template <> -float maybe2half(float x) { - return x; -} - -template <> -half maybe2half(float x) { - return __float2half_rn(x); -} - -template -__global__ void SkipLayerNormKernel( - const int ld, const T* input, const T* skip, const V* beta, const V* gamma, const T* bias, - const U epsilon, V* output, T* skip_input_bias_add_output) { - const U reverse_ld = U(1.f / ld); - const int offset = blockIdx.x * ld; - - KeyValuePairSum pair_sum; - // reduce x and x^2 - hipcub::KeyValuePair thread_data(U(0.f), U(0.f)); - - for (int i = threadIdx.x; i < ld; i += TPB) { - const int idx = offset + i; - const U val = (bias == nullptr) ? static_cast(input[idx]) + static_cast(skip[idx]) : static_cast(input[idx]) + static_cast(skip[idx]) + static_cast(bias[i]); - const U rldval = reverse_ld * val; - thread_data = pair_sum(thread_data, hipcub::KeyValuePair(rldval, rldval * val)); - - if (skip_input_bias_add_output != nullptr) { - skip_input_bias_add_output[idx] = static_cast(val); - } - - output[idx] = static_cast(val); - } - - if constexpr (Simplified) { - SimplifiedLayerNorm(thread_data.value, ld, offset, gamma, epsilon, output); - return; - } - - LayerNorm(thread_data, ld, offset, beta, gamma, epsilon, output); -} - -// Vectorized kernel -template -__global__ void SkipLayerNormKernelVec( - const int ld, const T* input, const T* skip, const V* beta, const V* gamma, - const T* bias, const U epsilon, V* output, T* skip_input_bias_add_output, - bool hasBias, bool hasSkipInputBiasAdditionOutput) { - const U reverse_ld = U(1.f / ld); - const int offset = blockIdx.x * ld; - - KeyValuePairSum pair_sum; - // reduce x and x^2 - hipcub::KeyValuePair thread_data(U(0.f), U(0.f)); - - using VecT = aligned_vector; - using VecV = aligned_vector; - if (threadIdx.x * ILP < ld) { - for (int i = threadIdx.x * ILP; i < ld; i += TPB * ILP) { - int idx = offset + i; - - const VecT input_v = *reinterpret_cast(input + idx); - const VecT skip_v = *reinterpret_cast(skip + idx); - const VecT bias_v = hasBias ? *reinterpret_cast(bias + i) : VecT(); - VecT skip_input_bias_add_output_v, output_v; - -#pragma unroll - for (int k = 0; k < ILP; k++) { - const U val = hasBias ? static_cast(input_v.val[k]) + static_cast(skip_v.val[k]) + static_cast(bias_v.val[k]) : static_cast(input_v.val[k]) + static_cast(skip_v.val[k]); - const U rldval = reverse_ld * val; - - if (hasSkipInputBiasAdditionOutput) { - skip_input_bias_add_output_v.val[k] = static_cast(val); - } - thread_data = pair_sum(thread_data, hipcub::KeyValuePair(rldval, rldval * val)); - output_v.val[k] = static_cast(val); - } - - if (hasSkipInputBiasAdditionOutput) { - *(reinterpret_cast(skip_input_bias_add_output + idx)) = skip_input_bias_add_output_v; - } - - *(reinterpret_cast(output + idx)) = output_v; - } - } - - if constexpr (Simplified) { - SimplifiedLayerNormVec(thread_data.value, ld, offset, gamma, epsilon, output); - return; - } - - LayerNormVec(thread_data, ld, offset, beta, gamma, epsilon, output); -} - -// Vectorized kernel -template -__global__ void SkipLayerNormKernelSmall( - const int ld, const T* input, const T* skip, const V* beta, const V* gamma, - const T* bias, const U epsilon, V* output, T* skip_input_bias_add_output, - bool hasBias, bool hasSkipInputBiasAdditionOutput) { - const U rld = U(1.f / ld); - const int idx = blockIdx.x * ld + threadIdx.x * ILP; // grid_size = n / ld - - using VecT = aligned_vector; - hipcub::KeyValuePair thread_data(U(0.f), U(0.f)); - - VecT input_v; - if (ILP * threadIdx.x < ld) { - input_v = *reinterpret_cast(input + idx); - const VecT skip_v = *reinterpret_cast(skip + idx); - const VecT bias_v = hasBias ? *reinterpret_cast(bias + threadIdx.x * ILP) : VecT(); - VecT skip_input_bias_add_output_v; - - U rldval_sum = U(0.f); - U rldvalsq_sum = U(0.f); -#pragma unroll - for (int i = 0; i < ILP; i++) { - const U val = hasBias ? static_cast(input_v.val[i]) + static_cast(skip_v.val[i]) + static_cast(bias_v.val[i]) : static_cast(input_v.val[i]) + static_cast(skip_v.val[i]); - - if (hasSkipInputBiasAdditionOutput) { - skip_input_bias_add_output_v.val[i] = static_cast(val); - } - - const U rldval = rld * val; - rldval_sum += rldval; - rldvalsq_sum += rldval * val; - input_v.val[i] = static_cast(val); - } - - if (hasSkipInputBiasAdditionOutput) { - *(reinterpret_cast(skip_input_bias_add_output + idx)) = skip_input_bias_add_output_v; - } - - thread_data = hipcub::KeyValuePair(rldval_sum, rldvalsq_sum); - } - - if constexpr (Simplified) { - SimplifiedLayerNormSmall(input_v.val, thread_data.value, ld, idx, gamma, epsilon, output); - return; - } - - LayerNormSmall(input_v.val, thread_data, ld, idx, beta, gamma, epsilon, output); -} - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_tunable_op.h b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_tunable_op.h deleted file mode 100644 index 0391704ce1c56..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_tunable_op.h +++ /dev/null @@ -1,161 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include -#include -#include - -#include "contrib_ops/rocm/bert/skip_layer_norm_impl_kernel.h" -#include "core/providers/rocm/cu_inc/common.cuh" -#include "core/providers/rocm/tunable/rocm_tunable.h" - -using onnxruntime::rocm::CeilDiv; - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -template -struct SkipLayerNormParams : OpParams { - SkipLayerNormParams(RocmTuningContext* tuning_ctx, onnxruntime::Stream* stream, V* output, T* skip_input_bias_add_output, const T* input, - const T* skip, const V* gamma, const V* beta, - const T* bias, float epsilon, int ld, int element_count) - : OpParams(tuning_ctx, stream), output(output), skip_input_bias_add_output(skip_input_bias_add_output), input(input), skip(skip), gamma(gamma), beta(beta), bias(bias), epsilon(epsilon), ld(ld), element_count(element_count) {} - - std::string Signature() const override { - std::string sig = std::to_string(ld) + "_" + std::to_string(element_count); - return sig; - } - - V* output; - T* skip_input_bias_add_output; - const T* input; - const T* skip; - const V* gamma; - const V* beta; - const T* bias; - float epsilon; - int ld; - int element_count; -}; - -template -Status SkipLayerNormSmallOp(const SkipLayerNormParams* params) { - // Loosen the hard constraint for ld (hidden_size) to include more possible *Small kernels, - // which could offer better performance in some combinations of ThreadsPerBlock and VecSize. - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - !((params->ld <= 8192 && params->ld % VecSize == 0 && - params->ld <= ThreadsPerBlock * VecSize && params->ld > (ThreadsPerBlock - GPU_WARP_SIZE) * VecSize))); - SkipLayerNormKernelSmall<<element_count, params->ld)), - dim3(ThreadsPerBlock), - 0, params->StreamHandle()>>>( - params->ld, params->input, params->skip, - params->beta, params->gamma, params->bias, static_cast(params->epsilon), params->output, params->skip_input_bias_add_output, - (params->bias == nullptr) ? false : true, (params->skip_input_bias_add_output == nullptr) ? false : true); - return HIP_CALL(hipGetLastError()); -} - -template -Status SkipLayerNormRegularOp(const SkipLayerNormParams* params) { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - !((params->ld > 0 && params->ld % VecSize == 0 && - (params->ld >= ThreadsPerBlock * VecSize || - (params->ld < GPU_WARP_SIZE && params->ld > (ThreadsPerBlock - GPU_WARP_SIZE) * VecSize))))); - SkipLayerNormKernelVec<<element_count, params->ld)), - dim3(ThreadsPerBlock), - 0, params->StreamHandle()>>>( - params->ld, params->input, params->skip, - params->beta, params->gamma, params->bias, static_cast(params->epsilon), params->output, params->skip_input_bias_add_output, - (params->bias == nullptr) ? false : true, (params->skip_input_bias_add_output == nullptr) ? false : true); - return HIP_CALL(hipGetLastError()); -} - -template -Status SkipLayerNormStaticSelection(const SkipLayerNormParams* params) { - bool hasBias = (params->bias == nullptr) ? false : true; - bool hasSkipInputBiasAdditionOutput = (params->skip_input_bias_add_output == nullptr) ? false : true; - const int grid_size = params->element_count / params->ld; - const int block_size = 256; - -#define LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(ELEMENTS, TPB, ILP) \ - if (params->ld <= ELEMENTS) { \ - SkipLayerNormKernelSmall<<StreamHandle()>>>( \ - params->ld, params->input, params->skip, params->beta, params->gamma, params->bias, \ - static_cast(params->epsilon), params->output, params->skip_input_bias_add_output, \ - hasBias, hasSkipInputBiasAdditionOutput); \ - break; \ - } - if (0 == (params->ld % 4)) { - do { - LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(32, 32, 1) - LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(64, 32, 2) - LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(128, 32, 4) - LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(384, 96, 4) - LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(768, 192, 4) - LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(1024, 256, 4) - - SkipLayerNormKernel<<StreamHandle()>>>( - params->ld, params->input, params->skip, params->beta, params->gamma, params->bias, - static_cast(params->epsilon), params->output, params->skip_input_bias_add_output); - } while (0); - } else { - do { - LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(32, 32, 1) - LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(64, 64, 1) - LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(128, 128, 1) - LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(384, 384, 1) - - SkipLayerNormKernel<<StreamHandle()>>>( - params->ld, params->input, params->skip, params->beta, params->gamma, params->bias, - static_cast(params->epsilon), params->output, params->skip_input_bias_add_output); - } while (0); - } - return HIP_CALL(hipPeekAtLastError()); -} // namespace rocm - -#define ADD_OP_FOR_ALL_VEC_SIZE(name, threads_per_block) \ - this->RegisterOp(name); \ - this->RegisterOp(name); \ - this->RegisterOp(name); \ - this->RegisterOp(name); \ - this->RegisterOp(name); - -#define ADD_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(name) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 64) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 128) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 192) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 256) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 320) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 384) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 448) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 512) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 576) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 640) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 704) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 768) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 832) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 896) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 1024) - -template -class SkipLayerNormTunableOp : public TunableOp> { - public: - SkipLayerNormTunableOp() { - this->RegisterOp(SkipLayerNormStaticSelection); - ADD_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(SkipLayerNormSmallOp) - ADD_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(SkipLayerNormRegularOp) - - // NOTE: the 1st kernel is SkipLayerNorm Original implementation. - this->SetDefaultId(0); - } -}; - -#undef ADD_OP_FOR_ALL_VEC_SIZE -#undef ADD_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/transformer_common.cc b/onnxruntime/contrib_ops/rocm/bert/transformer_common.cc deleted file mode 100644 index 6ae8d1202d462..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/transformer_common.cc +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. -#include -#include "core/providers/shared_library/provider_api.h" // Include this otherwise Windows build complains Env::Default() missing -#include "core/platform/env_var_utils.h" -#include "contrib_ops/rocm/bert/transformer_common.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -// The environment variable is for testing purpose only, and it might be removed in the future. -// If you need some option in production, please file a feature request. -constexpr const char* kTransformerOptions = "ORT_TRANSFORMER_OPTIONS"; - -// Initialize the singleton instance -TransformerOptions TransformerOptions::instance; - -const TransformerOptions* TransformerOptions::GetInstance() { - if (!instance.initialized_) { - // We do not use critical section here since it is fine to initialize multiple times by different threads. - int value = ParseEnvironmentVariableWithDefault(kTransformerOptions, 0); - instance.Initialize(value); - - if (value > 0) - std::cout << "ORT_TRANSFORMER_OPTIONS: IsPrecisionMode=" << instance.IsPrecisionMode() - << ",DisablePersistentSoftmax=" << instance.DisablePersistentSoftmax() - << ",DisableHalf2=" << instance.DisableHalf2() - << std::endl; - } - - return &instance; -} - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/transformer_common.h b/onnxruntime/contrib_ops/rocm/bert/transformer_common.h deleted file mode 100644 index 6816b5b9d07ec..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/transformer_common.h +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "core/providers/rocm/rocm_common.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -class TransformerOptions { - public: - static const TransformerOptions* GetInstance(); - - bool IsPrecisionMode() const { return is_precision_mode_; } - - bool DisablePersistentSoftmax() const { return disable_persistent_softmax_; } - - bool DisableHalf2() const { return disable_half2_; } - - void Initialize(int value) { - is_precision_mode_ = (value & 0x01) > 0; - disable_persistent_softmax_ = (value & 0x02) > 0; - disable_half2_ = (value & 0x04) > 0; - initialized_ = true; - } - - private: - // Default is false. If the mode is on, prefer precision than speed. - bool is_precision_mode_{false}; - - // Disable persistent softmax. - bool disable_persistent_softmax_{false}; - - // Disable half2 kernel. - bool disable_half2_{false}; - - bool initialized_{false}; - - static TransformerOptions instance; -}; - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh deleted file mode 100644 index d0a0d09fcbae3..0000000000000 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh +++ /dev/null @@ -1,105 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include -#include - -#ifdef USE_COMPOSABLE_KERNEL -#include "core/providers/rocm/composable_kernel_common.h" - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh" -#endif // USE_COMPOSABLE_KERNEL - -#include "contrib_ops/rocm/diffusion/group_norm_common.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -#ifdef USE_COMPOSABLE_KERNEL - -using onnxruntime::rocm::CKDataTypeAdaptor; - -// The SiLU function is a special case of Swish function, -// The Swish function is parametrized by b, which is set to 1.0 for SiLU. They are defined as: -// SiLU(x) = x * sigmoid(x) -// Swish(x) = x * sigmoid(bx) -// The default value of b is 1.0 in ck::tensor_operation::element_wise::Swish function. We treat them as the same function here. -using Silu = ck::tensor_operation::element_wise::Swish; -using Pass = ck::tensor_operation::element_wise::PassThrough; - -constexpr int Rank = 5; -constexpr int NumReduceDim = 3; - -template -auto GetCKGroupNormNHWCTypeStringAndOps() { - using XDataType = typename CKDataTypeAdaptor::type; - using YDataType = typename CKDataTypeAdaptor::type; - using SaveMeanInvStdDataType = typename CKDataTypeAdaptor::type; - using GammaDataType = float; - using BetaDataType = float; - - using Activation = std::conditional_t; - - std::vector>>> ret; - for (auto&& impl : internal::GetDeviceGroupNormInstances()) { - std::string silu_suffix = WithSilu ? "_Silu" : "_Pass"; - auto type_string = onnxruntime::MakeString(impl->GetTypeString()) + silu_suffix; - auto invoker = impl->MakeInvokerPointer(); - - auto ck_group_norm_op = [impl = std::move(impl), invoker = std::move(invoker)]( - const GroupNormNHWCTunableParams* params) -> Status { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF((params->skip != nullptr || params->bias != nullptr), - "Input skip or bias is not supported by composable kernel."); - if constexpr (WithSilu) { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - !params->use_silu, "Silu version only support groupnorm with silu"); - } else { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->use_silu, "Pass version only support groupnorm without silu"); - } - std::vector in_lengths{params->n, params->h, params->w, params->groups, params->channels_per_group}; - std::vector in_out_strides{params->h * params->w * params->c, params->w * params->c, - params->c, params->channels_per_group, 1}; - std::vector gamma_beta_strides{0, 0, 0, params->channels_per_group, 1}; - std::vector reduce_dims{1, 2, 4}; - - auto activation = Activation{}; - - auto arg = impl->MakeArgumentPointer(in_lengths, // lengths - in_out_strides, // xStrides - gamma_beta_strides, // gammaStrides - gamma_beta_strides, // betaStrides - in_out_strides, // yStrides - {0, 0}, // saveMeanStrides - {0, 0}, // saveInvStdStrides - reduce_dims, // reduceDims - params->epsilon, - params->src, - params->gamma, - params->beta, - params->dst, - nullptr, - nullptr, - activation); - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), - impl->GetTypeString(), " does not support the params"); - invoker->Run(arg.get(), StreamConfig{params->StreamHandle()}); - return Status::OK(); - }; - ret.emplace_back(std::make_pair(std::move(type_string), std::move(ck_group_norm_op))); - } - return ret; -} -#endif // USE_COMPOSABLE_KERNEL - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh deleted file mode 100644 index 68f7d47282845..0000000000000 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh +++ /dev/null @@ -1,130 +0,0 @@ -// SPDX-License-Identifier: MIT -// Modifications Copyright (c) Microsoft Corporation. All rights reserved. -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#ifdef USE_COMPOSABLE_KERNEL -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_normalization_fwd.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_normalization_fwd_impl.hpp" -#include "ck/utility/data_type.hpp" - -namespace onnxruntime { -namespace contrib { -namespace rocm { -namespace internal { - -using F16 = ck::half_t; -using F32 = float; - -using Silu = ck::tensor_operation::element_wise::Swish; -using Pass = ck::tensor_operation::element_wise::PassThrough; - -using ck::tensor_operation::device::DeviceNormalizationFwd; // the interface -using ck::tensor_operation::device::DeviceNormalizationFwdImpl; // the implementation - -// See https://github.com/ROCmSoftwarePlatform/composable_kernel/blob/1fefd82ed8/library/src/tensor_operation_instance/gpu/normalization_fwd/normalization_fwd_instance_common.hpp - -template -using device_normalization_f32_instances = std::tuple< - // clang-format off - // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize, SaveMeanInvStdScalarPerVector> - DeviceNormalizationFwdImpl, // irregular size - DeviceNormalizationFwdImpl, // irregular size - DeviceNormalizationFwdImpl, // irregular size - DeviceNormalizationFwdImpl, // irregular size - DeviceNormalizationFwdImpl, // irregular size - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl - // clang-format on - >; - -template -using device_normalization_f16_instances = - // clang-format off - std::tuple < - // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize, SaveMeanInvStdScalarPerVector> - DeviceNormalizationFwdImpl, // irregular size - DeviceNormalizationFwdImpl, // irregular size - DeviceNormalizationFwdImpl, // irregular size - DeviceNormalizationFwdImpl, // irregular size - DeviceNormalizationFwdImpl, // irregular size - DeviceNormalizationFwdImpl, // irregular size - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl - // clang-format on - >; - -// Use this function to get implementation -template -std::vector>> -GetDeviceGroupNormInstances() { - return {}; -} - -template <> -std::vector>> -GetDeviceGroupNormInstances< - F16, F32, F32, F16, F32, Silu, 5, 3>(); - -template <> -std::vector>> -GetDeviceGroupNormInstances< - F16, F32, F32, F16, F32, Pass, 5, 3>(); - -template <> -std::vector>> -GetDeviceGroupNormInstances< - F32, F32, F32, F32, F32, Silu, 5, 3>(); - -template <> -std::vector>> -GetDeviceGroupNormInstances< - F32, F32, F32, F32, F32, Pass, 5, 3>(); - -} // namespace internal -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime -#endif // USE_COMPOSABLE_KERNEL diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp16.cu b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp16.cu deleted file mode 100644 index ad191314e5e4c..0000000000000 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp16.cu +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#ifdef USE_COMPOSABLE_KERNEL -#include "contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace onnxruntime { -namespace contrib { -namespace rocm { -namespace internal { - -template <> -std::vector>> -GetDeviceGroupNormInstances() { - std::vector>> instances; - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, - device_normalization_f16_instances{}); - - return instances; -} - -template <> -std::vector>> -GetDeviceGroupNormInstances() { - std::vector>> instances; - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, - device_normalization_f16_instances{}); - - return instances; -} - -} // namespace internal -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime -#endif // USE_COMPOSABLE_KERNEL diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp32.cu b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp32.cu deleted file mode 100644 index ceb53ed442abc..0000000000000 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp32.cu +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#ifdef USE_COMPOSABLE_KERNEL -#include "contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace onnxruntime { -namespace contrib { -namespace rocm { -namespace internal { - -template <> -std::vector>> -GetDeviceGroupNormInstances() { - std::vector>> instances; - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, - device_normalization_f32_instances{}); - - return instances; -} - -template <> -std::vector>> -GetDeviceGroupNormInstances() { - std::vector>> instances; - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, - device_normalization_f32_instances{}); - - return instances; -} - -} // namespace internal -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime -#endif // USE_COMPOSABLE_KERNEL diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_common.h b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_common.h deleted file mode 100644 index 7cff640db2f34..0000000000000 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_common.h +++ /dev/null @@ -1,56 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include - -#include "core/providers/rocm/cu_inc/common.cuh" -#include "core/providers/rocm/rocm_common.h" -#include "core/providers/rocm/tunable/rocm_tunable.h" -#include "contrib_ops/rocm/diffusion/group_norm_common_base.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -template -struct GroupNormNHWCTunableParams : OpParams, GroupNormNHWCParams { - GroupNormNHWCTunableParams(RocmTuningContext* tuning_ctx, - onnxruntime::Stream* ort_stream, - T* output, - T* add_out, - const T* input, - const T* skip, - const T* bias, - const float* gamma, - const float* beta, - float* workspace, - float epsilon, - int batch_size, - int num_channels, - int height, - int width, - int num_groups, - bool use_silu, - bool broadcast_skip, - int channels_per_block) - : OpParams(tuning_ctx, ort_stream), - GroupNormNHWCParams(output, add_out, input, skip, bias, gamma, beta, workspace, epsilon, batch_size, - num_channels, height, width, num_groups, use_silu, broadcast_skip, channels_per_block) {} - - std::string Signature() const override { - std::string silu_suffix = this->use_silu ? "_silu" : "_pass"; - std::string skip_suffix = this->skip != nullptr ? "_skip" : "_noskip"; - std::string broadcast_suffix = this->broadcast_skip ? "_broadcast" : "_nobroadcast"; - std::string bias_suffix = this->bias != nullptr ? "_bias" : "_nobias"; - std::string sig = std::to_string(this->n) + "_" + std::to_string(this->h * this->w) + "_" + - std::to_string(this->c) + "_" + std::to_string(this->groups) + silu_suffix + - skip_suffix + broadcast_suffix + bias_suffix; - return sig; - } -}; - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.cu b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.cu deleted file mode 100644 index 142aaf14e8d2d..0000000000000 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.cu +++ /dev/null @@ -1,76 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -// The ROCM kernel is hipified from CUDA kernel. -#include "contrib_ops/rocm/diffusion/group_norm_impl.h" - -#include -#include "contrib_ops/rocm/diffusion/group_norm_common.h" -#include "contrib_ops/rocm/diffusion/group_norm_tunable_op.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -template -Status LaunchGroupNormKernel( - RocmTuningContext* tuning_ctx, - Stream* ort_stream, - T* output, - T* add_out, - const T* input, - const T* skip, - const T* bias, - const float* gamma, - const float* beta, - void* workspace, - float epsilon, - int batch_size, - int num_channels, - int height, - int width, - int num_groups, - bool use_silu, - bool broadcast_skip, - int channels_per_block) { - GroupNormNHWCTunableParams params(tuning_ctx, ort_stream, output, add_out, input, skip, bias, gamma, beta, - reinterpret_cast(workspace), epsilon, batch_size, num_channels, - height, width, num_groups, use_silu, broadcast_skip, channels_per_block); - - if (params.channels_per_block % params.channels_per_group != 0 || - params.channels_per_block > kMaxSize || - (params.channels_per_group % CHANNELS_PER_THREAD != 0)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, - "GroupNorm in ROCM does not support the input: n=", batch_size, - " h=", height, - " w=", width, - " c=", num_channels, - " groups=", num_groups); - } - - HIP_RETURN_IF_ERROR(hipMemsetAsync( - params.group_sum_buffer, 0, GetGroupNormWorkspaceSizeInBytes(batch_size, num_groups), params.StreamHandle())); - - if (tuning_ctx->IsTunableOpEnabled()) { - static GroupNormNHWCTunableOp op; - return op(¶ms); - } - - return GroupNormNHWCStaticSelection(¶ms); -} - -template Status LaunchGroupNormKernel(RocmTuningContext* tuning_ctx, Stream* stream, half* output, - half* add_out, const half* input, const half* skip, const half* bias, - const float* gamma, const float* beta, void* workspace, float epsilon, - int batch_size, int num_channels, int height, int width, int num_groups, - bool use_silu, bool broadcast_skip, int channels_per_block); - -template Status LaunchGroupNormKernel(RocmTuningContext* tuning_ctx, Stream* stream, float* output, - float* add_out, const float* input, const float* skip, const float* bias, - const float* gamma, const float* beta, void* workspace, float epsilon, - int batch_size, int num_channels, int height, int width, int num_groups, - bool use_silu, bool broadcast_skip, int channels_per_block); - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh deleted file mode 100644 index c6ca16bfdfc80..0000000000000 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh +++ /dev/null @@ -1,105 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include -#include - -#include "contrib_ops/rocm/diffusion/group_norm_common.h" -#include "core/providers/rocm/triton_kernel.h" - -using namespace onnxruntime::rocm; - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -#ifdef USE_TRITON_KERNEL - -namespace { - -template -std::string GetGroupNormTritonGroupName() { - std::string ret = "GroupNormTriton_"; - std::string silu_suffix = WithSilu ? "Silu_" : "Pass_"; - ret += silu_suffix; - ret += GetDataTypeName(); - return ret; -} - -} // namespace - -template -auto GetTritonGroupNormNHWCTypeStringAndOps() { - std::vector>>> ret; - auto group_name = GetGroupNormTritonGroupName(); - auto* kernel_list = GetOrtTritonKernelByGroup(group_name); - if (kernel_list == nullptr) { - return ret; - } - - for (auto i : *kernel_list) { - // Check params match - auto* metadata = GetOrtTritonKernelMetadata(i); - auto block_size = metadata->constants.at("BLOCK_SIZE"); - auto hw_size = metadata->constants.at("HW_SIZE"); - auto impl = [i, block_size, hw_size](const GroupNormNHWCTunableParams* params) -> Status { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->channels_per_group > block_size || params->channels_per_group * 2 <= block_size, - "Arg block_size (", block_size, ") is not the next power of 2 of channels_per_group (", - params->channels_per_group, ")."); - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->hw % hw_size != 0, "Arg hw_size (", hw_size, ") is not a divisor of hw (", params->hw, ")."); - if constexpr (WithSilu) { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!params->use_silu, "Silu version does not support GN w/o silu."); - } else { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(params->use_silu, "Pass version does not support GN w/ silu."); - } - // Construct args for launch kernel - struct { - const void* src; - const void* skip; - const void* bias; - void* out; - void* add_out; - const void* gamma; - const void* beta; - int hw; - int c; - int c_per_group; - float eps; - bool has_skip; - bool has_bias; - bool broadcast_skip; - } args = { - (const void*)params->src, - (const void*)params->skip, - (const void*)params->bias, - (void*)params->dst, - (void*)params->skip_workspace, - (const void*)params->gamma, - (const void*)params->beta, - params->hw, - params->c, - params->channels_per_group, - params->epsilon, - params->skip != nullptr, - params->bias != nullptr, - params->broadcast_skip, - }; - - // Grid dim is (batch_count, groups, 1) - return LaunchTritonKernel(params->StreamHandle(), i, params->n, params->groups, 1, &args, sizeof(args)); - }; - ret.emplace_back(std::make_pair(metadata->name, std::move(impl))); - } - return ret; -} - -#endif // USE_TRITON_KERNEL - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py deleted file mode 100644 index 5ba96ebc117f0..0000000000000 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py +++ /dev/null @@ -1,135 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- - -from itertools import product - -import triton -import triton.language as tl - - -@triton.jit -def group_norm_kernel( - input_ptr, - skip_ptr, - bias_ptr, - output_ptr, - add_out_ptr, - gamma_ptr, - beta_ptr, - img_size, - c, - c_per_group, - eps, - has_skip, - has_bias, - broadcast_skip, - BLOCK_SIZE: tl.constexpr, - HW_SIZE: tl.constexpr, - ACTIVATION_SILU: tl.constexpr, -): - row_x = tl.program_id(0) - row_y = tl.program_id(1) - stride = img_size * c - input_ptr += row_x * stride + row_y * c_per_group - output_ptr += row_x * stride + row_y * c_per_group - gamma_ptr += row_y * c_per_group - beta_ptr += row_y * c_per_group - - cols = tl.arange(0, BLOCK_SIZE) - hw = tl.arange(0, HW_SIZE) - offsets = hw[:, None] * c + cols[None, :] - mask = (cols < c_per_group)[None, :] - - bias = tl.zeros([BLOCK_SIZE], dtype=tl.float32) - if has_skip: - add_out_ptr += row_x * stride + row_y * c_per_group - if broadcast_skip: - broadcast_skip_ptr = skip_ptr + row_x * c + row_y * c_per_group - bias += tl.load(broadcast_skip_ptr + cols, mask=cols < c_per_group, other=0.0).to(tl.float32) - else: - skip_ptr += row_x * stride + row_y * c_per_group - if has_bias: - bias_ptr += row_y * c_per_group - bias += tl.load(bias_ptr + cols, mask=cols < c_per_group, other=0.0).to(tl.float32) - - # Calculate mean and variance - _sum = tl.zeros([HW_SIZE, BLOCK_SIZE], dtype=tl.float32) - _square_sum = tl.zeros([HW_SIZE, BLOCK_SIZE], dtype=tl.float32) - for i in range(tl.cdiv(img_size, HW_SIZE)): - x_ptr = input_ptr + i * HW_SIZE * c - a = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32) - if has_skip and not broadcast_skip: - s_ptr = skip_ptr + i * HW_SIZE * c - s = tl.load(s_ptr + offsets, mask=mask, other=0.0).to(tl.float32) - a += s - if has_bias or broadcast_skip: - a += bias - _sum += a - _square_sum += a * a - if has_skip: - add_y_ptr = add_out_ptr + i * HW_SIZE * c - tl.store(add_y_ptr + offsets, a, mask=mask) - - # Set axis=None (or leave it unspecified) to reduce all axes. - # TODO: In older Triton we have to reduce an axis at a time, but in our case - # for some configs it may have some issue when reducing sequentially along the axes. - group_mean = tl.sum(_sum, axis=None) / (img_size * c_per_group) - group_var = tl.sum(_square_sum, axis=None) / (img_size * c_per_group) - group_mean * group_mean - - rstd = 1 / tl.sqrt(group_var + eps) - - # Normalize and apply linear transformation - gamma = tl.load(gamma_ptr + cols, mask=cols < c_per_group).to(tl.float32) - beta = tl.load(beta_ptr + cols, mask=cols < c_per_group).to(tl.float32) - for i in range(tl.cdiv(img_size, HW_SIZE)): - y_ptr = output_ptr + i * HW_SIZE * c - if has_skip: - add_y_ptr = add_out_ptr + i * HW_SIZE * c - x = tl.load(add_y_ptr + offsets, mask=mask, other=0.0).to(tl.float32) - else: - x_ptr = input_ptr + i * HW_SIZE * c - x = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32) - x_hat = (x - group_mean) * rstd - y = x_hat * gamma + beta - if ACTIVATION_SILU: - y *= tl.sigmoid(y) - tl.store(y_ptr + offsets, y, mask=mask) - - -# We can have more combinations of blocks and hw_sizes, e.g., -# blocks = [16, 32, 64, 128, 256, 512] -# hw_sizes = [8, 16, 32, 64, 128, 256, 512] -# but this will result in too many functions and slow down the compilation. -with_silu = [True, False] -dtypes = ["fp32", "fp16"] -blocks = [16, 32, 64, 128] -hw_sizes = [8, 16, 32, 64, 128, 256] -warps = [1, 2, 4, 8, 16] -name_pattern = "GroupNormTriton_{}_{}_b{}_hw{}_w{}" -sig_pattern = "*{},*{},*{},*{},*{},*fp32,*fp32,i32,i32,i32,fp32,i1,i1,i1" -group_pattern = "GroupNormTriton_{}_{}" - - -def get_function_table(): - func_table = [] - - for silu, dtype, hw_size, warp, b in product(with_silu, dtypes, hw_sizes, warps, blocks): - silu_suffix = "Silu" if silu else "Pass" - name = name_pattern.format(silu_suffix, dtype, b, hw_size, warp) - group = group_pattern.format(silu_suffix, dtype) - sig = sig_pattern.format(dtype, dtype, dtype, dtype, dtype) - kwargs = { - "num_warps": warp, - "constants": {"BLOCK_SIZE": b, "HW_SIZE": hw_size, "ACTIVATION_SILU": int(silu)}, - } - func_desc = {"name": name, "group": group, "func": group_norm_kernel, "sig": sig, "kwargs": kwargs} - func_table.append(func_desc) - return func_table - - -if __name__ == "__main__": - func_table = get_function_table() - for func_desc in func_table: - print(func_desc) diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_tunable_op.h b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_tunable_op.h deleted file mode 100644 index e6831f764b418..0000000000000 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_tunable_op.h +++ /dev/null @@ -1,220 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include -#include "core/providers/rocm/cu_inc/common.cuh" -#include "core/providers/rocm/rocm_common.h" -#include "contrib_ops/rocm/diffusion/group_norm_ck.cuh" -#include "contrib_ops/rocm/diffusion/group_norm_common.h" -#include "contrib_ops/rocm/diffusion/group_norm_impl.h" -#include "contrib_ops/rocm/diffusion/group_norm_impl_kernel.cuh" -#include "contrib_ops/rocm/diffusion/group_norm_triton.cuh" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -using onnxruntime::rocm::GPU_WARP_SIZE; - -template -void GroupNormNHWCSum(const GroupNormNHWCTunableParams* params) { - dim3 grid; - - // The number of blocks to compute all the channels. - grid.x = DivUp(params->c, params->channels_per_block); - // The number of blocks to compute all the activations in a given instance. - grid.y = DivUp(params->hw, params->hw_per_block); - // The number of instances. - grid.z = params->n; - -#define LAUNCH_GROUPNORM_SUM(ThreadsPerBlock, VecSize) \ - GroupNormNHWCSumKernel \ - <<StreamHandle()>>>( \ - params->skip_workspace, params->group_sum_buffer, params->src, params->skip, params->bias, \ - params->channels_per_block, params->hw_per_block, params->hw, params->hwc, params->c, \ - params->channels_per_group, params->groups, params->groups_per_block, params->broadcast_skip); \ - break; - - // Threads_per_block is half of values in kSizes since CHANNELS_PER_THREAD = 2. - switch (params->threads_per_block) { - case 256: - LAUNCH_GROUPNORM_SUM(256, CHANNELS_PER_THREAD) - case 192: - LAUNCH_GROUPNORM_SUM(192, CHANNELS_PER_THREAD) - case 160: - LAUNCH_GROUPNORM_SUM(160, CHANNELS_PER_THREAD) - case 128: - LAUNCH_GROUPNORM_SUM(128, CHANNELS_PER_THREAD) - case 64: - LAUNCH_GROUPNORM_SUM(64, CHANNELS_PER_THREAD) - default: - ORT_NOT_IMPLEMENTED("Not implemented"); - } -} - -template -Status GroupNormNHWCSumOp(const GroupNormNHWCTunableParams* params) { - dim3 grid; - grid.x = DivUp(params->c, params->channels_per_block); - grid.y = DivUp(params->hw, params->hw_per_block); - grid.z = params->n; - - GroupNormNHWCSumKernel - <<StreamHandle()>>>( - params->skip_workspace, params->group_sum_buffer, params->src, params->skip, params->bias, - params->channels_per_block, params->hw_per_block, params->hw, params->hwc, params->c, - params->channels_per_group, params->groups, params->groups_per_block, params->broadcast_skip); - return HIP_CALL(hipGetLastError()); -} - -template -void GroupNormNHWCScale(const GroupNormNHWCTunableParams* params) { - dim3 grid; - - // The number of blocks to compute all the channels. - grid.x = DivUp(params->c, params->channels_per_block); - // The number of blocks to compute all the activations in a given instance. - grid.y = DivUp(params->hw, params->hw_per_block); - // The number of instances. - grid.z = params->n; - -#define LAUNCH_GROUPNORM_SCALE(ThreadsPerBlock, VecSize) \ - GroupNormNHWCScaleKernel \ - <<StreamHandle()>>>( \ - params->dst, params->src, params->skip, params->gamma, params->beta, params->skip_workspace, \ - params->group_sum_buffer, params->epsilon, params->c, params->channels_per_block, \ - params->channels_per_group, params->groups, params->hwc, params->inv_hw_channels_per_group, \ - params->hw, params->hw_per_block, params->use_silu); \ - break; - - // Threads_per_block is half of values in kSizes since CHANNELS_PER_THREAD = 2. - switch (params->threads_per_block) { - case 256: - LAUNCH_GROUPNORM_SCALE(256, CHANNELS_PER_THREAD) - case 192: - LAUNCH_GROUPNORM_SCALE(192, CHANNELS_PER_THREAD) - case 160: - LAUNCH_GROUPNORM_SCALE(160, CHANNELS_PER_THREAD) - case 128: - LAUNCH_GROUPNORM_SCALE(128, CHANNELS_PER_THREAD) - case 64: - LAUNCH_GROUPNORM_SCALE(64, CHANNELS_PER_THREAD) - default: - ORT_NOT_IMPLEMENTED("Not implemented"); - } -} - -template -Status GroupNormNHWCScaleOp(const GroupNormNHWCTunableParams* params) { - dim3 grid; - grid.x = DivUp(params->c, params->channels_per_block); - grid.y = DivUp(params->hw, params->hw_per_block); - grid.z = params->n; - - GroupNormNHWCScaleKernel - <<StreamHandle()>>>( - params->dst, params->src, params->skip, params->gamma, params->beta, params->skip_workspace, - params->group_sum_buffer, params->epsilon, params->c, params->channels_per_block, params->channels_per_group, - params->groups, params->hwc, params->inv_hw_channels_per_group, params->hw, params->hw_per_block, - params->use_silu); - return HIP_CALL(hipGetLastError()); -} - -template -class GroupNormNHWCOp { - public: - Status operator()(const GroupNormNHWCTunableParams* params) { - HIP_RETURN_IF_ERROR(hipMemsetAsync(params->group_sum_buffer, - 0, - GetGroupNormWorkspaceSizeInBytes(params->n, params->groups), - params->StreamHandle())); - auto status = GroupNormNHWCSumOp(params); - ORT_RETURN_IF_ERROR(status); - HIP_RETURN_IF_ERROR(hipGetLastError()); - status = GroupNormNHWCScaleOp(params); - ORT_RETURN_IF_ERROR(status); - HIP_RETURN_IF_ERROR(hipGetLastError()); - return Status::OK(); - } - - Status IsSupported(const GroupNormNHWCTunableParams* params) { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - !(params->c % VecSize == 0 && params->channels_per_group % VecSize == 0), - "The number of channels (", params->c, ") or the number of channels per group (", params->channels_per_group, - ") isn't divisible by the number of vector size: ", VecSize); - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!(params->channels_per_block <= ThreadsPerBlock * VecSize && - params->channels_per_block > (ThreadsPerBlock - GPU_WARP_SIZE) * VecSize), - "Configuration: Threads (", ThreadsPerBlock, "), vector size (", - VecSize, ") is redundant for the number of channels per group: ", - params->channels_per_block); - - return Status::OK(); - } -}; - -template -Status GroupNormNHWCStaticSelection(const GroupNormNHWCTunableParams* params) { - HIP_RETURN_IF_ERROR(hipMemsetAsync(params->group_sum_buffer, - 0, - GetGroupNormWorkspaceSizeInBytes(params->n, params->groups), - params->StreamHandle())); - GroupNormNHWCSum(params); - HIP_RETURN_IF_ERROR(hipGetLastError()); - GroupNormNHWCScale(params); - HIP_RETURN_IF_ERROR(hipGetLastError()); - return Status::OK(); -} - -#define ADD_OP_FOR_ALL_VEC_SIZE(name, threads_per_block) \ - this->RegisterOp(name{}); \ - this->RegisterOp(name{}); \ - this->RegisterOp(name{}); - -#define ADD_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(name) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 64) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 128) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 192) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 256) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 320) - -template -class GroupNormNHWCTunableOp : public TunableOp> { - public: - GroupNormNHWCTunableOp() { - this->RegisterOp(GroupNormNHWCStaticSelection); - ADD_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(GroupNormNHWCOp) - -#ifdef USE_COMPOSABLE_KERNEL - for (auto&& [_, op] : GetCKGroupNormNHWCTypeStringAndOps()) { - ORT_UNUSED_PARAMETER(_); - this->RegisterOp(std::move(op)); - } - - for (auto&& [_, op] : GetCKGroupNormNHWCTypeStringAndOps()) { - ORT_UNUSED_PARAMETER(_); - this->RegisterOp(std::move(op)); - } -#endif // USE_COMPOSABLE_KERNEL - -#ifdef USE_TRITON_KERNEL - for (auto&& [_, op] : GetTritonGroupNormNHWCTypeStringAndOps()) { - ORT_UNUSED_PARAMETER(_); - this->RegisterOp(std::move(op)); - } - for (auto&& [_, op] : GetTritonGroupNormNHWCTypeStringAndOps()) { - ORT_UNUSED_PARAMETER(_); - this->RegisterOp(std::move(op)); - } -#endif - } -}; - -#undef ADD_OP_FOR_ALL_VEC_SIZE -#undef ADD_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/diffusion/nhwc_conv.cc b/onnxruntime/contrib_ops/rocm/diffusion/nhwc_conv.cc deleted file mode 100644 index 35427a02c631d..0000000000000 --- a/onnxruntime/contrib_ops/rocm/diffusion/nhwc_conv.cc +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/providers/rocm/nn/conv.h" - -using namespace onnxruntime::rocm; - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -#define REGISTER_KERNEL_TYPED(T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - NhwcConv, \ - kMSDomain, \ - 1, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - Conv); - -REGISTER_KERNEL_TYPED(float) -REGISTER_KERNEL_TYPED(MLFloat16) - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/fused_conv.cc b/onnxruntime/contrib_ops/rocm/fused_conv.cc deleted file mode 100644 index 4f3be98d97f80..0000000000000 --- a/onnxruntime/contrib_ops/rocm/fused_conv.cc +++ /dev/null @@ -1,439 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include -#include -#include "core/common/status.h" -#include "core/providers/rocm/nn/conv.h" -#include "core/providers/rocm/rocm_common.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -namespace { - -// Copied from hipDNN/library/src/hcc_detail/hipdnn_miopen.cpp -miopenStatus_t _miopenAddTensor( - miopenHandle_t handle, - const void* alpha, - const miopenTensorDescriptor_t aDesc, - const void* A, - const void* beta, - const miopenTensorDescriptor_t cDesc, - void* C, - const void* zero_scalar) { - const miopenTensorOp_t tensorOp = miopenTensorOpAdd; - // Using miopenOpTensor to implement Add operator. - // opnd2 = Add ( 0.0 * opnd0, alpha * opnd1 ) + beta * opnd2 - return miopenOpTensor(handle, tensorOp, - zero_scalar, cDesc, C, - alpha, aDesc, A, - beta, cDesc, C); -} - -} // namespace - -template -struct FNVHash { - uint32_t GetValue() const { return value_; } - - void Hash(const void* in_ptr, size_t nbytes) { - auto ptr = reinterpret_cast(in_ptr); - for (size_t i = 0; i < nbytes; ++i) { - value_ ^= ptr[i]; - value_ *= PRIME; - } - } - - template ::value, size_t>::type = 0> - FNVHash& operator<<(const T& pod) { - Hash(&pod, sizeof(pod)); - return *this; - } - - template - FNVHash& operator<<(const std::vector& pod_array) { - for (const auto& pod : pod_array) { - (*this) << pod; - } - return *this; - } - - void HashTensor(miopenTensorDescriptor_t tdesc) { - int size = 0; - miopenGetTensorDescriptorSize(tdesc, &size); - (*this) << size; - std::vector dims(size); - std::vector strides(size); - miopenDataType_t dtype; - miopenGetTensorDescriptor(tdesc, &dtype, dims.data(), strides.data()); - (*this) << dtype; - (*this) << dims; - (*this) << strides; - } - - void HashConvolutionDescriptor(miopenConvolutionDescriptor_t cdesc) { - int spatial_dim = 1; -#if ROCM_VERSION >= 50500 - MIOPEN_CALL(miopenGetConvolutionSpatialDim(cdesc, &spatial_dim)); - std::vector pads{spatial_dim}; - std::vector strides{spatial_dim}; - std::vector dilations{spatial_dim}; - miopenConvolutionMode_t mode; - MIOPEN_CALL(miopenGetConvolutionNdDescriptor(cdesc, spatial_dim, &spatial_dim, pads.data(), strides.data(), dilations.data(), &mode)); -#else - // Previous versions of MIOpen doesn't provide API to probe the dimension of a - // miopenConvolutionDescriptor_t, so we have to guess. - // This algorithm is based on a specific behavior of miopenGetConvolutionNdDescriptor, - // which fails when requestedSpatialDim > the convolution's spatial dimension - constexpr const int kMaxSpatialDim = 5; - std::vector pads{kMaxSpatialDim}; - std::vector strides{kMaxSpatialDim}; - std::vector dilations{kMaxSpatialDim}; - miopenConvolutionMode_t mode; - bool spatial_dim_guessed = false; - for (int i = 0; i < kMaxSpatialDim; i++) { - if (miopenStatusSuccess == miopenGetConvolutionNdDescriptor( - cdesc, i, &spatial_dim, pads.data(), strides.data(), dilations.data(), &mode)) { - spatial_dim_guessed = true; - break; - } - } - ORT_ENFORCE(spatial_dim_guessed, "Failed to guess the actual spatial dimension"); - // Remove the extra dimension - pads.resize(spatial_dim); - strides.resize(spatial_dim); - dilations.resize(spatial_dim); -#endif - (*this) << spatial_dim; - (*this) << pads; - (*this) << strides; - (*this) << dilations; - (*this) << mode; - } - - private: - uint32_t value_ = BASIS; -}; - -template -class FusedConv : public onnxruntime::rocm::Conv { - public: - using Base = onnxruntime::rocm::Conv; - FusedConv(const OpKernelInfo& info) : onnxruntime::rocm::Conv(info) { - std::string activation; - ORT_THROW_IF_ERROR(info.GetAttr("activation", &activation)); - ORT_THROW_IF_ERROR(MapMode(activation)); - MIOPEN_CALL_THROW(miopenCreateActivationDescriptor(&activation_desc_)); - MIOPEN_CALL_THROW(miopenSetActivationDescriptor(activation_desc_, activation_mode_, 0.0, 0.0, 0.0)); - MIOPEN_CALL_THROW(miopenCreateOperatorArgs(&fusion_args_)); - } - - ORT_DISALLOW_COPY_AND_ASSIGNMENT(FusedConv); - - ~FusedConv() { - if (activation_desc_) { - MIOPEN_CALL_THROW(miopenDestroyActivationDescriptor(activation_desc_)); - activation_desc_ = nullptr; - } - - if (fusion_args_) { - miopenDestroyOperatorArgs(fusion_args_); - } - } - - Status ComputeInternal(OpKernelContext* context) const override { - std::lock_guard lock(Base::s_.mutex); - - ORT_RETURN_IF_ERROR(Base::UpdateState(context, true)); - if (Base::s_.Y->Shape().Size() == 0) { - return Status::OK(); - } - - bool has_z = nullptr != Base::s_.z_data; - bool has_b = nullptr != Base::s_.b_data; - auto factory = [this](FusedConvFusionData& fusion) { - return this->DoCreateFusionDesc(this->Node().Name(), fusion); - }; - auto& cached_item = plan_cache_.FindOrCreateFusionPlanCache(Hash(), - factory); - bool should_try_fusion_api = cached_item.Validate(this->GetMiopenHandle(context)); - - typedef typename onnxruntime::rocm::ToHipType::MappedType HipT; - const auto alpha = onnxruntime::rocm::Consts::One; - const auto beta = onnxruntime::rocm::Consts::Zero; - IAllocatorUniquePtr workspace = Base::GetWorkSpace(context->GetComputeStream()); - miopenStatus_t fusion_status = miopenStatusNotInitialized; - - if (should_try_fusion_api) { - auto& fusion_info = *cached_item.fusion; - MIOPEN_RETURN_IF_ERROR(miopenSetOpArgsConvForward(fusion_args_, - fusion_info.conv_op, - &alpha, - &beta, - Base::s_.w_data)); - if (has_z) { - MIOPEN_RETURN_IF_ERROR(miopenSetOpArgsBiasForward(fusion_args_, - fusion_info.bias_z_op, - &alpha, - &beta, - Base::s_.z_data)); - } - if (has_b) { - MIOPEN_RETURN_IF_ERROR(miopenSetOpArgsBiasForward(fusion_args_, - fusion_info.bias_b_op, - &alpha, - &beta, - Base::s_.b_data)); - } - if (activation_desc_) { - const float relu_notused = 0.0; - MIOPEN_RETURN_IF_ERROR(miopenSetOpArgsActivForward(fusion_args_, - fusion_info.act_op, - &alpha, - &beta, - relu_notused, - relu_notused, - relu_notused)); - } - fusion_status = miopenExecuteFusionPlan(this->GetMiopenHandle(context), - fusion_info.plan, - Base::s_.x_tensor, - Base::s_.x_data, - Base::s_.y_tensor, - Base::s_.y_data, - fusion_args_); - } - if (miopenStatusSuccess != fusion_status) { - MIOPEN_RETURN_IF_ERROR(miopenConvolutionForward(this->GetMiopenHandle(context), - &alpha, - Base::s_.x_tensor, - Base::s_.x_data, - Base::s_.w_desc, - Base::s_.w_data, - Base::s_.conv_desc, - Base::s_.fwd_algo, - &beta, - Base::s_.y_tensor, - Base::s_.y_data, - workspace.get(), - Base::s_.workspace_bytes)); - if (has_b) { - MIOPEN_RETURN_IF_ERROR(_miopenAddTensor(this->GetMiopenHandle(context), - &alpha, Base::s_.b_tensor, Base::s_.b_data, - &alpha, Base::s_.y_tensor, Base::s_.y_data, - &beta)); - } - if (has_z) { - MIOPEN_RETURN_IF_ERROR(_miopenAddTensor(this->GetMiopenHandle(context), - &alpha, Base::s_.z_tensor, Base::s_.z_data, - &alpha, Base::s_.y_tensor, Base::s_.y_data, - &beta)); - } - MIOPEN_RETURN_IF_ERROR(miopenActivationForward(this->GetMiopenHandle(context), - activation_desc_, - &alpha, - Base::s_.y_tensor, - Base::s_.y_data, - &beta, - Base::s_.y_tensor, - Base::s_.y_data)); - } - if (Base::s_.post_slicing_required) { - ORT_RETURN_IF_ERROR(onnxruntime::rocm::SliceOutUnwantedOutputSection( - this->Stream(context), - Base::s_.y_data, - Base::s_.y_dims_with_adjusted_pads, - Base::s_.Y->MutableDataRaw(), - Base::s_.y_dims.GetDims(), - Base::s_.slice_starts, - Base::s_.slice_ends, - Base::s_.slice_axes, - Base::s_.element_size)); - } - return Status::OK(); - } - - private: - Status MapMode(const std::string& activaton_mode) { - if (activaton_mode == "Relu") { - activation_mode_ = miopenActivationMode_t::miopenActivationRELU; - } else { - return ORT_MAKE_STATUS( - StatusCategory::ONNXRUNTIME, StatusCode::INVALID_ARGUMENT, - "unsupported conv activation mode \"", activaton_mode, "\""); - } - return Status::OK(); - } - miopenActivationMode_t activation_mode_; - miopenActivationDescriptor_t activation_desc_ = nullptr; - - miopenOperatorArgs_t fusion_args_ = nullptr; - - // MIOpen Fusion API - // TODO: create one fusion descriptor shared by multiple FusedConv - // objects - // - // Considerations: - // How to determine two FusedConv objects may share the same fusion - // descriptor? Hashing x_tensor,conv_desc, etc.? - struct FusedConvFusionData { - miopenFusionPlanDescriptor_t plan = nullptr; - miopenFusionOpDescriptor_t conv_op = nullptr; - miopenFusionOpDescriptor_t bias_b_op = nullptr; - miopenFusionOpDescriptor_t bias_z_op = nullptr; - miopenFusionOpDescriptor_t act_op = nullptr; - - // TODO: There is a potential problem. miopenHandle_t may be destroyed and - // re-created later, sharing the same address. Currently there is any way - // to detect it? - mutable std::unordered_set compiled_on; - - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(FusedConvFusionData); - - FusedConvFusionData() {} - ~FusedConvFusionData() { - if (plan) { - miopenDestroyFusionPlan(plan); - } - } - }; - - struct FusionPlanCacheItem { - std::unique_ptr fusion; - Status creation_result; - // TODO: Add a timestamp for eviction - // std::chrono::time_point last_access; - - FusionPlanCacheItem() {} - - miopenStatus_t CompileOnHandle(miopenHandle_t handle) const { - if (!fusion->plan) { - return miopenStatusNotInitialized; - } - auto iter = fusion->compiled_on.find(handle); - if (iter != fusion->compiled_on.end()) { - return miopenStatusSuccess; - } - auto ret = miopenCompileFusionPlan(handle, fusion->plan); - if (miopenStatusSuccess == ret) { - fusion->compiled_on.insert(handle); - } else { - return ret; - } - return miopenStatusSuccess; - } - - bool Validate(miopenHandle_t handle) const { - if (Status::OK() != creation_result) { - return false; - } - if (!fusion || !fusion->plan) { - return false; - } - auto compiling_status = CompileOnHandle(handle); - if (miopenStatusSuccess != compiling_status) { - return false; - } - - return true; - } - }; - - struct FusionPlanCache { - mutable std::mutex mutex; - using HashKey = uint32_t; - std::unordered_map cache_directory_; - - FusionPlanCache() { - } - - FusionPlanCacheItem& FindOrCreateFusionPlanCache(HashKey key, - std::function factory) { - std::lock_guard lock(mutex); - auto iter = cache_directory_.find(key); - if (iter == cache_directory_.end()) { - cache_directory_[key].fusion = std::make_unique(); - cache_directory_[key].creation_result = factory(*cache_directory_[key].fusion); - if (Status::OK() != cache_directory_[key].creation_result) { - cache_directory_[key].fusion.reset(); - } - } - return cache_directory_[key]; - } - }; - - static FusionPlanCache plan_cache_; - - Status DoCreateFusionDesc(const std::string& node_name, FusedConvFusionData& fusion) const { - bool has_z = nullptr != Base::s_.z_data; - bool has_b = nullptr != Base::s_.b_data; - MIOPEN_RETURN_IF_ERROR(miopenCreateFusionPlan(&fusion.plan, - miopenVerticalFusion, - Base::s_.x_tensor)); - auto status = miopenCreateOpConvForward(fusion.plan, &fusion.conv_op, Base::s_.conv_desc, Base::s_.w_desc); - if (status == miopenStatusUnsupportedOp) { - auto msg = MakeString("MIOpen does not support the conv fusion for node \"", - node_name, "\", fallback to unfused implementation."); - LOGS_DEFAULT(WARNING) << msg; - return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, msg); - } - MIOPEN_RETURN_IF_ERROR(status); - - if (has_z) { - MIOPEN_RETURN_IF_ERROR(miopenCreateOpBiasForward(fusion.plan, - &fusion.bias_z_op, - Base::s_.z_tensor)); - } - if (has_b) { - MIOPEN_RETURN_IF_ERROR(miopenCreateOpBiasForward(fusion.plan, - &fusion.bias_b_op, - Base::s_.b_tensor)); - } - if (activation_desc_) { - MIOPEN_RETURN_IF_ERROR(miopenCreateOpActivationForward(fusion.plan, - &fusion.act_op, - activation_mode_)); - } - return Status::OK(); - } - - uint32_t Hash() const { - FNVHash hash; - bool has_z = nullptr != Base::s_.z_data; - bool has_b = nullptr != Base::s_.b_data; - hash.HashTensor(Base::s_.x_tensor); - hash.HashConvolutionDescriptor(Base::s_.conv_desc); - hash.HashTensor(Base::s_.w_desc); - if (has_z) { - hash.HashTensor(Base::s_.z_tensor); - } - if (has_b) { - hash.HashTensor(Base::s_.b_tensor); - } - if (activation_desc_) { - hash << static_cast(activation_mode_); - } - return hash.GetValue(); - } -}; - -template -typename FusedConv::FusionPlanCache FusedConv::plan_cache_; - -#define REGISTER_KERNEL_TYPED(T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - FusedConv, \ - kMSDomain, \ - 1, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - FusedConv); - -REGISTER_KERNEL_TYPED(float); -REGISTER_KERNEL_TYPED(MLFloat16); -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/math/gemm_float8.cu b/onnxruntime/contrib_ops/rocm/math/gemm_float8.cu deleted file mode 100644 index 3539f32252944..0000000000000 --- a/onnxruntime/contrib_ops/rocm/math/gemm_float8.cu +++ /dev/null @@ -1,213 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/common/common.h" -#include "core/common/float16.h" -#include "core/providers/rocm/rocm_kernel.h" -#include "contrib_ops/rocm/math/gemm_float8_ck.cuh" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -using namespace onnxruntime::rocm; -using namespace onnxruntime::rocm::tunable::blas; - -class GemmFloat8 final : public RocmKernel { - public: - GemmFloat8(const OpKernelInfo& info) : RocmKernel(info) { - transA_ = info.GetAttrOrDefault("transA", 0); - transB_ = info.GetAttrOrDefault("transB", 0); - dtype_ = info.GetAttrOrDefault("dtype", onnx::TensorProto_DataType_FLOAT16); - alpha_ = info.GetAttrOrDefault("alpha", 1); - beta_ = info.GetAttrOrDefault("beta", 0); - } - Status ComputeInternal(OpKernelContext* ctx) const override; - - private: -#if !defined(DISABLE_FLOAT8_TYPES) - template - Status ComputeFp8Fp16Fp16(OpKernelContext* ctx, int64_t m, int64_t n, int64_t k, - const Tensor* A, const Tensor* scaleA, const Tensor* B, Tensor* C) const; - template - Status ComputeFp16Fp8Fp16(OpKernelContext* ctx, int64_t m, int64_t n, int64_t k, - const Tensor* A, const Tensor* B, const Tensor* scaleB, Tensor* C) const; - - template - [[nodiscard]] inline auto* GetOp() const { - using OpT = GemmFloat8TunableOp; - if (tunable_op_) { - return static_cast(tunable_op_.get()); - } - - auto create = std::make_unique(); // avoid new - tunable_op_ = std::shared_ptr(create.release(), [](void* ptr) { - auto release = std::unique_ptr(); // avoid delete - release.reset(static_cast(ptr)); - }); - - return static_cast(tunable_op_.get()); - } -#endif - - float alpha_; - float beta_; - bool transA_; - bool transB_; - int64_t dtype_; - - // fully type erased - mutable std::shared_ptr tunable_op_; -}; - -Status GemmFloat8::ComputeInternal(OpKernelContext* ctx) const { -#if defined(DISABLE_FLOAT8_TYPES) - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "DISABLE_FLOAT8_TYPES"); -#else - const Tensor* A = ctx->Input(0); - const Tensor* B = ctx->Input(1); - const Tensor* C = ctx->Input(2); // bias - const Tensor* scale_a = ctx->Input(3); - const Tensor* scale_b = ctx->Input(4); - const Tensor* scale_y = ctx->Input(5); - - auto a_shape = A->Shape(); - auto b_shape = B->Shape(); - ORT_ENFORCE(a_shape.NumDimensions() == 2); - ORT_ENFORCE(b_shape.NumDimensions() == 2); - - auto m = !transA_ ? a_shape[0] : a_shape[1]; - auto k = !transA_ ? a_shape[1] : a_shape[0]; - ORT_ENFORCE(k == (!transB_ ? b_shape[0] : b_shape[1])); // k is compatible - auto n = !transB_ ? b_shape[1] : b_shape[0]; - - TensorShapeVector output_shape = {m, n}; - Tensor* Y = ctx->Output(0, output_shape); - - ORT_ENFORCE(!transA_, "ROCm GemmFloat8 does not support input A transpose"); - ORT_ENFORCE(dtype_ == onnx::TensorProto_DataType_FLOAT16, "ROCm GemmFloat8 only supports output float16"); - ORT_ENFORCE(C == nullptr, "ROCm GemmFloat8 does not support bias input"); - ORT_ENFORCE(scale_y == nullptr, "ROCm GemmFloat8 does not support output scaling"); - - if (A->IsDataType()) { - return ComputeFp8Fp16Fp16(ctx, m, n, k, A, scale_a, B, Y); - } else if (A->IsDataType()) { - return ComputeFp8Fp16Fp16(ctx, m, n, k, A, scale_a, B, Y); - } else if (B->IsDataType()) { - return ComputeFp16Fp8Fp16(ctx, m, n, k, A, B, scale_b, Y); - } else if (B->IsDataType()) { - return ComputeFp16Fp8Fp16(ctx, m, n, k, A, B, scale_b, Y); - } - - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unhandled type combination of GemmFloat8"); -#endif -} - -#if !defined(DISABLE_FLOAT8_TYPES) -template -Status GemmFloat8::ComputeFp8Fp16Fp16( - OpKernelContext* ctx, int64_t m, int64_t n, int64_t k, - const Tensor* A, const Tensor* scale_a, const Tensor* B, Tensor* C) const { - ORT_ENFORCE(A->IsDataType() && scale_a->IsDataType() && B->IsDataType()); - - onnxruntime::rocm::tunable::blas::GemmFloat8Params params{}; - params.tuning_ctx = GetTuningContext(); - params.stream = ctx->GetComputeStream(); - params.handle = GetHipblasHandle(ctx); - params.opa = transA_ ? tunable::blas::BlasOp::Trans : tunable::blas::BlasOp::NonTrans; - params.opb = transB_ ? tunable::blas::BlasOp::Trans : tunable::blas::BlasOp::NonTrans; - - params.m = m; - params.n = n; - params.k = k; - - params.a = static_cast(A->DataRaw()); - params.lda = transA_ ? m : k; - params.scale_a = alpha_; - params.scale_a_dev = static_cast(scale_a->DataRaw()); - - params.b = static_cast(B->DataRaw()); - params.ldb = transB_ ? k : n; - params.scale_b = 1.0f; // NOTE: not used - params.scale_b_dev = nullptr; // NOTE: not used - - params.c = static_cast(C->MutableDataRaw()); - params.ldc = n; - params.scale_c = 1.0f; // NOTE: not implemented - params.scale_c_dev = nullptr; // NOTE: not implemented - - if (!transA_ && !transB_) { - return (*GetOp())(¶ms); - } else if (transA_ && !transB_) { - ORT_NOT_IMPLEMENTED("transA is not implemented"); - } else if (!transA_ && transB_) { - ORT_NOT_IMPLEMENTED("transB is not implemented"); - } else if (transA_ && transB_) { - ORT_NOT_IMPLEMENTED("transA & transB is not implemented"); - } - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unreachable"); -} - -template -Status GemmFloat8::ComputeFp16Fp8Fp16( - OpKernelContext* ctx, int64_t m, int64_t n, int64_t k, - const Tensor* A, const Tensor* B, const Tensor* scale_b, Tensor* C) const { - ORT_ENFORCE(A->IsDataType() && B->IsDataType() && scale_b->IsDataType()); - - onnxruntime::rocm::tunable::blas::GemmFloat8Params params{}; - params.tuning_ctx = GetTuningContext(); - params.stream = ctx->GetComputeStream(); - params.handle = GetHipblasHandle(ctx); - params.opa = transA_ ? tunable::blas::BlasOp::Trans : tunable::blas::BlasOp::NonTrans; - params.opb = transB_ ? tunable::blas::BlasOp::Trans : tunable::blas::BlasOp::NonTrans; - - params.m = m; - params.n = n; - params.k = k; - - params.a = static_cast(A->DataRaw()); - params.lda = transA_ ? m : k; - params.scale_a = 1.0f; // NOTE: not used - params.scale_a_dev = nullptr; // NOTE: not used - - params.b = static_cast(B->DataRaw()); - params.ldb = transB_ ? k : n; - params.scale_b = alpha_; - params.scale_b_dev = static_cast(scale_b->DataRaw()); - - params.c = static_cast(C->MutableDataRaw()); - params.ldc = n; - params.scale_c = 1.0f; // NOTE: not implemented - params.scale_c_dev = nullptr; // NOTE: not implemented - - if (!transA_ && !transB_) { - return (*GetOp())(¶ms); - } else if (transA_ && !transB_) { - ORT_NOT_IMPLEMENTED("transA is not implemented"); - } else if (!transA_ && transB_) { - return (*GetOp())(¶ms); - } else if (transA_ && transB_) { - ORT_NOT_IMPLEMENTED("transA & transB is not implemented"); - } - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unreachable"); -} -#define GEMM_FLOAT8_CONSTRAINTS BuildKernelDefConstraints() -#else -#define GEMM_FLOAT8_CONSTRAINTS BuildKernelDefConstraints() -#endif - -ONNX_OPERATOR_KERNEL_EX( - GemmFloat8, - kMSDomain, - 1, - kRocmExecutionProvider, - (*KernelDefBuilder::Create()) - .TypeConstraint("TA", GEMM_FLOAT8_CONSTRAINTS) - .TypeConstraint("TB", GEMM_FLOAT8_CONSTRAINTS) - .TypeConstraint("TR", BuildKernelDefConstraints()) - .TypeConstraint("TS", BuildKernelDefConstraints()), - GemmFloat8); - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck.cuh b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck.cuh deleted file mode 100644 index b545eb1f2a149..0000000000000 --- a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck.cuh +++ /dev/null @@ -1,276 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include -#include - -#if defined(USE_COMPOSABLE_KERNEL) - -#include "core/providers/rocm/composable_kernel_common.h" - -#include "ck/ck.hpp" -#include "ck/utility/functional3.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/device_gemm_splitk.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#endif - -#if !defined(DISABLE_FLOAT8_TYPES) -#include "core/common/float8.h" -#endif -#include "core/providers/rocm/tunable/gemm_common.h" - -namespace onnxruntime { -namespace rocm { -namespace tunable { - -#if defined(USE_COMPOSABLE_KERNEL) && !defined(DISABLE_FLOAT8_TYPES) -using F8 = ck::f8_t; -using F16 = ck::half_t; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -template -constexpr bool always_false = false; - -template -struct Scale { - constexpr const static bool is_pack2_invocable = true; - constexpr const static bool is_pack4_invocable = true; - - explicit Scale(float scale_value, const float* dev_scale_ptr) : scale_value_{scale_value}, dev_scale_ptr_{dev_scale_ptr} {} - - template - __forceinline__ __host__ __device__ Y fast_type_convert(X x) const { - static_assert(always_false, "not implemented"); - (void)x; - } - - template <> - __forceinline__ __host__ __device__ ck::half_t fast_type_convert(ck::f8_t x) const { - // https://github.com/ROCmSoftwarePlatform/triton/blob/0cc3f8b84a16892396f6e08a04991034d67e32b1/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp#L220-L233 - constexpr const uint16_t mask = 0x7fff; - constexpr const uint16_t sign_mask = 0x8000; - constexpr const uint16_t exp_compensate = []() { - if constexpr (std::is_same_v) { - return 0x2000; - } else if constexpr (std::is_same_v) { - return 0x1c00; - } - }(); - - uint8_t x_u8 = reinterpret_cast(x); - uint16_t x_u16 = static_cast(x_u8) << 8; - uint16_t exp = (x_u16 & mask) >> 1; - uint16_t y = (x_u16 & sign_mask) | (exp + exp_compensate); - return reinterpret_cast(y); - } - - __forceinline__ __host__ __device__ void operator()(ck::half_t& y, const ck::f8_t& x) const { - float scale = scale_value_ * (*dev_scale_ptr_); - y = ck::type_convert(scale * fast_type_convert(x)); - } - - __forceinline__ __host__ __device__ void operator()(ck::half2_t& ys, const ck::f8x2_t& xs) const { - float scale = scale_value_ * (*dev_scale_ptr_); - constexpr const uint32_t mask = 0x7fff7fff; - constexpr const uint32_t sign_mask = 0x80008000; - constexpr const uint32_t exp_compensate = []() { - if constexpr (std::is_same_v) { - return 0x20002000; - } else if constexpr (std::is_same_v) { - return 0x1c001c00; - } - }(); - - const uchar2& x2_u8 = reinterpret_cast(xs); - uchar4 x{0, x2_u8.x, 0, x2_u8.y}; - uint32_t x_u32 = reinterpret_cast(x); - - uint32_t exp = (x_u32 & mask) >> 1; - uint32_t v = (x_u32 & sign_mask) | (exp + exp_compensate); - ys = scale * reinterpret_cast(v); - } - - __forceinline__ __host__ __device__ void operator()(ck::half4_t& ys, const ck::f8x4_t& xs) const { - float scale = scale_value_ * (*dev_scale_ptr_); - constexpr const uint32_t mask = 0x7fff7fff; - constexpr const uint32_t sign_mask = 0x80008000; - constexpr const uint32_t exp_compensate = []() { - if constexpr (std::is_same_v) { - return 0x20002000; - } else if constexpr (std::is_same_v) { - return 0x1c001c00; - } - }(); - - uint32_t xs_u32 = reinterpret_cast(xs); - uint32_t x_u32_0 = __byte_perm(xs_u32, 0, 0x1504); - uint32_t x_u32_1 = __byte_perm(xs_u32, 0, 0x3726); - uint32_t exp_0 = (x_u32_0 & mask) >> 1; - uint32_t exp_1 = (x_u32_1 & mask) >> 1; - uint32_t v_0 = (x_u32_0 & sign_mask) | (exp_0 + exp_compensate); - uint32_t v_1 = (x_u32_1 & sign_mask) | (exp_1 + exp_compensate); - uint64_t v = v_0 | uint64_t(v_1) << 32; - ys = scale * reinterpret_cast(v); - } - - float scale_value_; - const float* const dev_scale_ptr_; -}; -#endif - -namespace blas { - -template -struct GemmFloat8Params : tunable::OpParams { - std::string Signature() const override { - return MakeString(BlasOpToString(opa), BlasOpToString(opb), "_", m, "_", n, "_", k); - } - - hipblasHandle_t handle; - BlasOp opa; - BlasOp opb; - int64_t m; - int64_t n; - int64_t k; - float scale_a{}; - const float* scale_a_dev{}; - const TA* a; - int64_t lda; - float scale_b{}; - const float* scale_b_dev{}; - const TB* b; - int64_t ldb; - TC* c; - float scale_c{}; - const float* scale_c_dev{}; - int64_t ldc; -}; - -#if defined(USE_COMPOSABLE_KERNEL) && !defined(DISABLE_FLOAT8_TYPES) - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -using Nop = ck::tensor_operation::element_wise::PassThrough; - -void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances( - std::vector, Nop, Nop>>>& instances); - -void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances( - std::vector, Nop, Nop>>>& instances); - -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances( - std::vector, Nop>>>& instances); - -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances( - std::vector, Nop>>>& instances); - -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances( - std::vector, Nop>>>& instances); - -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances( - std::vector, Nop>>>& instances); - -template -auto CreateOp(float scale, const float* dev_scale) { - if constexpr (std::is_same_v) { - return Scale(scale, dev_scale); - } else if constexpr (std::is_same_v) { - return Scale(scale, dev_scale); - } else { - return Nop{}; - } -} - -template -auto GetCKF8SplitKGemmTypeStringAndOps() { - using CKTA = typename CKDataTypeAdaptor::type; - using CKTB = typename CKDataTypeAdaptor::type; - using CKTC = typename CKDataTypeAdaptor::type; - - using CKLayoutA = typename CKBlasOpAdaptor::type; - using CKLayoutB = typename CKBlasOpAdaptor::type; - - using OpA = std::conditional_t, Scale, Nop>; - using OpB = std::conditional_t, Scale, Nop>; - using OpC = std::conditional_t, Scale, Nop>; - - using DeviceGemm = ck::tensor_operation::device::DeviceGemmSplitK< - CKLayoutA, CKLayoutB, Row, - CKTA, CKTB, CKTC, - OpA, OpB, OpC>; - - std::vector>>> ret; - - for (auto num_split : {1, 4, 16, 64}) { - std::vector> instances{}; - if constexpr (std::is_same_v && std::is_same_v && std::is_same_v && - std::is_same_v && std::is_same_v) { - add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances(instances); - } else if constexpr (std::is_same_v && std::is_same_v && std::is_same_v && - std::is_same_v && std::is_same_v) { - add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances(instances); - } else if constexpr (std::is_same_v && std::is_same_v && std::is_same_v && - std::is_same_v && std::is_same_v) { - add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances(instances); - } else { - static_assert(always_false, "no instances for the type combination"); - LOGS_DEFAULT(FATAL) << "no instances for the type combination"; - } - for (auto&& impl : instances) { - auto type_string = std::to_string(ret.size()) + "_" + impl->GetTypeString() + "_SplitK" + std::to_string(num_split); - auto invoker = impl->MakeInvokerPointer(); - auto ck_gemm_op = [num_split, impl = std::move(impl), invoker = std::move(invoker)](const GemmFloat8Params* params) -> Status { - OpA op_a = CreateOp(params->scale_a, params->scale_a_dev); - OpB op_b = CreateOp(params->scale_b, params->scale_b_dev); - OpC op_c = CreateOp(params->scale_c, params->scale_c_dev); - - auto arg = impl->MakeArgumentPointer(params->a, params->b, params->c, - params->m, params->n, params->k, - params->lda, params->ldb, params->ldc, - op_a, op_b, op_c, num_split); - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), - impl->GetTypeString(), " does not support ", params->Signature()); - invoker->Run(arg.get(), StreamConfig{params->StreamHandle()}); - return Status::OK(); - }; - ret.emplace_back(std::make_pair(std::move(type_string), std::move(ck_gemm_op))); - } - } - return ret; -} - -#endif // USE_COMPOSABLE_KERNEL - -template -class GemmFloat8TunableOp : public TunableOp> { - public: - GemmFloat8TunableOp() { -#if defined(USE_COMPOSABLE_KERNEL) && !defined(DISABLE_FLOAT8_TYPES) - for (auto&& [_, op] : GetCKF8SplitKGemmTypeStringAndOps()) { - ORT_UNUSED_PARAMETER(_); - this->RegisterOp(std::move(op)); - } -#else - ORT_ENFORCE(false, "CK is required to support GemmFloat8 computing"); -#endif // USE_COMPOSABLE_KERNEL - } -}; - -} // namespace blas -} // namespace tunable -} // namespace rocm -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/add_instance.cu b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/add_instance.cu deleted file mode 100644 index 4c691dd18f2e9..0000000000000 --- a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/add_instance.cu +++ /dev/null @@ -1,124 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include - -#if defined(USE_COMPOSABLE_KERNEL) && !defined(DISABLE_FLOAT8_TYPES) - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp" - -#include "contrib_ops/rocm/math/gemm_float8_ck.cuh" - -namespace onnxruntime { -namespace rocm { -namespace tunable { -namespace blas { - -using F8 = ck::f8_t; -using F16 = ck::half_t; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -template -using S = ck::Sequence; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -namespace internal { -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck( - std::vector, PassThrough>>>& instances); - -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck( - std::vector, PassThrough>>>& instances); - -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort( - std::vector, PassThrough>>>& instances); - -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort( - std::vector, PassThrough>>>& instances); -} // namespace internal - -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances( - std::vector, PassThrough>>>& instances) { - internal::add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck(instances); - internal::add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort(instances); -} - -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances( - std::vector, PassThrough>>>& instances) { - internal::add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck(instances); - internal::add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort(instances); -} - -namespace internal { -void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck( - std::vector, PassThrough, PassThrough>>>& instances); - -void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck( - std::vector, PassThrough, PassThrough>>>& instances); - -// TODO: The first try of derivation does not going well due to various constraints. -// void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ort( -// std::vector, PassThrough, PassThrough>>>& instances); - -// TODO: The first try of derivation does not going well due to various constraints. -// void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ort( -// std::vector, PassThrough, PassThrough>>>& instances); -} // namespace internal - -void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances( - std::vector, PassThrough, PassThrough>>>& instances) { - internal::add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck(instances); - // internal::add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ort(instances); // TODO: -} - -void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances( - std::vector, PassThrough, PassThrough>>>& instances) { - internal::add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck(instances); - // internal::add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ort(instances); // TODO: -} - -namespace internal { -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_ck( - std::vector, PassThrough>>>& instances); - -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_ck( - std::vector, PassThrough>>>& instances); -} // namespace internal - -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances( - std::vector, PassThrough>>>& instances) { - internal::add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_ck(instances); -} - -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances( - std::vector, PassThrough>>>& instances) { - internal::add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_ck(instances); -} - -} // namespace blas -} // namespace tunable -} // namespace rocm -} // namespace onnxruntime - -#endif diff --git a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance.cu b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance.cu deleted file mode 100644 index 49463e58886f8..0000000000000 --- a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance.cu +++ /dev/null @@ -1,97 +0,0 @@ -// SPDX-License-Identifier: MIT -// Modifications Copyright (c) Microsoft. -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#if defined(USE_COMPOSABLE_KERNEL) && !defined(DISABLE_FLOAT8_TYPES) - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp" - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -#include "contrib_ops/rocm/math/gemm_float8_ck.cuh" - -namespace onnxruntime { -namespace rocm { -namespace tunable { -namespace blas { -namespace internal { - -template -using S = ck::Sequence; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; -static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - -using ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle; - -template -using device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_generic = std::tuple< - // clang-format off - //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| - //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| - //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNKPadding, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 1, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 8>, 2>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNKPadding, 64, 32, 32, 4, 8, 32, 32, 1, 1, S<1, 2, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 1, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 2> - // clang-format on - >; - -// The derived version is simply double BBlockTransferSrcScalarPerVector and adjust other values correspondingly -template -using device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort = std::tuple< - // clang-format off - //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| - //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Type| - //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| | - //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 256, 128, 8, 4, 32, 32, 4, 2, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 128, 256, 8, 4, 32, 32, 2, 4, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 128, 128, 8, 4, 32, 32, 4, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 64, 192, 8, 4, 32, 32, 1, 3, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 192, 64, 8, 4, 32, 32, 3, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 128, 128, 8, 4, 32, 32, 2, 2, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 64, 128, 8, 4, 32, 32, 2, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 128, 64, 8, 4, 32, 32, 2, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 1, 1, S<1, 16, 1, 4>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 64, 128, 8, 4, 32, 32, 1, 2, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 32, 192, 8, 4, 32, 32, 1, 3, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 12, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 16, 4, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 192, 32, 8, 4, 32, 32, 3, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 32, 64, 8, 4, 32, 32, 1, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 64, 32, 8, 4, 32, 32, 1, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 32, 128, 8, 4, 32, 32, 1, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 128, 32, 8, 4, 32, 32, 2, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 1, 1, S<1, 32, 1, 4>, 8, F16> - // clang-format on - >; - -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort( - std::vector, PassThrough>>>& instances) { - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort{}); - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_generic{}); -} - -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort( - std::vector, PassThrough>>>& instances) { - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort{}); - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_generic{}); -} - -} // namespace internal -} // namespace blas -} // namespace tunable -} // namespace rocm -} // namespace onnxruntime - -#endif diff --git a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance_original.cu b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance_original.cu deleted file mode 100644 index 236e5555051fc..0000000000000 --- a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance_original.cu +++ /dev/null @@ -1,80 +0,0 @@ -// SPDX-License-Identifier: MIT -// Modifications Copyright (c) Microsoft. -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#if defined(USE_COMPOSABLE_KERNEL) && !defined(DISABLE_FLOAT8_TYPES) - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp" - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -#include "contrib_ops/rocm/math/gemm_float8_ck.cuh" - -namespace onnxruntime { -namespace rocm { -namespace tunable { -namespace blas { -namespace internal { - -template -using S = ck::Sequence; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; - -using ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle; - -// Compilation parameters for a[m, k] * b[k, n] = c[m, n] -template -using device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck = std::tuple< - // clang-format off - //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| - //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Type| - //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| | - //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 64, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 48, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 192, 64, 4, 8, 32, 32, 3, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 32, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 192, 32, 4, 8, 32, 32, 3, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 64, 32, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16> - // clang-format on - >; - -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck( - std::vector, PassThrough>>>& instances) { - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck{}); -} - -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck( - std::vector, PassThrough>>>& instances) { - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck{}); -} - -} // namespace internal -} // namespace blas -} // namespace tunable -} // namespace rocm -} // namespace onnxruntime - -#endif diff --git a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instance.cu b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instance.cu deleted file mode 100644 index 1a0d45df82a71..0000000000000 --- a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instance.cu +++ /dev/null @@ -1,94 +0,0 @@ -// SPDX-License-Identifier: MIT -// Modifications Copyright (c) Microsoft. -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#if defined(USE_COMPOSABLE_KERNEL) && !defined(DISABLE_FLOAT8_TYPES) - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp" - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -#include "contrib_ops/rocm/math/gemm_float8_ck.cuh" - -namespace onnxruntime { -namespace rocm { -namespace tunable { -namespace blas { -namespace internal { - -template -using S = ck::Sequence; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; -static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - -using ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle; - -template -using device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_generic = std::tuple< - // clang-format off - //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| - //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| - //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNKPadding, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 1, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 2, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNKPadding, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 1, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 2, F16> - // clang-format on - >; - -// Compilation parameters for a[m, k] * b[k, n] = c[m, n] -template -using device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances = std::tuple< - // clang-format off - //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| - //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| - //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 4>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 4>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 4>, 8, F16> - // clang-format on - >; - -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_ck( - std::vector, PassThrough>>>& instances) { - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances{}); - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_generic{}); -} - -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_ck( - std::vector, PassThrough>>>& instances) { - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances{}); - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_generic{}); -} - -} // namespace internal -} // namespace blas -} // namespace tunable -} // namespace rocm -} // namespace onnxruntime - -#endif diff --git a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instance_original.cu b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instance_original.cu deleted file mode 100644 index a0628802ec09e..0000000000000 --- a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instance_original.cu +++ /dev/null @@ -1,97 +0,0 @@ -// SPDX-License-Identifier: MIT -// Modifications Copyright (c) Microsoft. -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#if defined(USE_COMPOSABLE_KERNEL) && !defined(DISABLE_FLOAT8_TYPES) - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp" - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -#include "contrib_ops/rocm/math/gemm_float8_ck.cuh" - -namespace onnxruntime { -namespace rocm { -namespace tunable { -namespace blas { -namespace internal { - -template -using S = ck::Sequence; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; -static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - -using ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle; - -template -using device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_generic = std::tuple< - // clang-format off - //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| - //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| - //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNKPadding, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 1, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 8>, 2>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNKPadding, 64, 32, 32, 4, 8, 32, 32, 1, 1, S<1, 2, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 1, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 2> - // clang-format on - >; - -// Compilation parameters for a[m, k] * b[k, n] = c[m, n] -template -using device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck = std::tuple< - // clang-format off - //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| - //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Type| - //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| | - //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 256, 64, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 48, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 256, 192, 64, 4, 8, 32, 32, 3, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 128, 32, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 128, 192, 32, 4, 8, 32, 32, 3, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 128, 64, 32, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16> - // clang-format on - >; - -void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck( - std::vector, PassThrough, PassThrough>>>& instances) { - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck{}); - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_generic{}); -} - -void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck( - std::vector, PassThrough, PassThrough>>>& instances) { - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck{}); - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_generic{}); -} - -} // namespace internal -} // namespace blas -} // namespace tunable -} // namespace rocm -} // namespace onnxruntime - -#endif diff --git a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc deleted file mode 100644 index 7dbb24463961e..0000000000000 --- a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc +++ /dev/null @@ -1,347 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/providers/shared_library/provider_api.h" -#include "core/providers/rocm/rocm_common.h" - -using namespace onnxruntime::common; - -namespace onnxruntime { -namespace contrib { -namespace rocm { -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, GridSample); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, FastGelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, FastGelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, FastGelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FastGelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, Gelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, Gelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, Gelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, Gelu); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BiasGelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, BiasSplitGelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, BiasSplitGelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, BiasAdd); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, BiasAdd); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, QuickGelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, QuickGelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, QuickGelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, TransposeMatMul); // backward compatibility -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, TransposeMatMul); // backward compatibility -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, TransposeMatMul); // backward compatibility -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, FusedMatMul); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, FusedMatMul); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, FusedMatMul); -// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, RelativePositionBias); -// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, RelativePositionBias); -// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, GatedRelativePositionBias); -// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, GatedRelativePositionBias); -// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, RemovePadding); -// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, RemovePadding); -// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, RestorePadding); -// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, RestorePadding); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, Rfft); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, Rfft); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, Rfft); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, Irfft); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, Irfft); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, Irfft); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, ComplexMul); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, ComplexMul); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, ComplexMulConj); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, ComplexMulConj); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BiasSoftmax); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BiasDropout); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BitmaskDropout); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BitmaskBiasDropout); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, NGramRepeatBlock); - -// These ops were experimental ops in onnx domain which have been removed now. We add them here as -// contrib ops to maintain backward compatibility -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, Affine); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double, Affine); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, Affine); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, Attention); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, Attention); -// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, PackedAttention); -// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, PackedAttention); -// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, PackedMultiHeadAttention); -// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, PackedMultiHeadAttention); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BeamSearch); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, ConvTransposeWithDynamicPads); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, Crop); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double, Crop); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, Crop); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, MultiHeadAttention); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, MultiHeadAttention); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, GroupQueryAttention); -// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, GroupQueryAttention); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, DecoderAttention); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, DecoderAttention); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, int32_t, DynamicSlice); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, int64_t, DynamicSlice); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, EmbedLayerNormalization); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, EmbedLayerNormalization); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, GreedySearch); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, GroupNorm); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, NhwcConv); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, NhwcConv); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, ImageScaler); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double, ImageScaler); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, ImageScaler); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, LongformerAttention); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, LongformerAttention); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, ParametricSoftplus); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double, ParametricSoftplus); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, ParametricSoftplus); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, Sampling); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, ScaledTanh); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double, ScaledTanh); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, ScaledTanh); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, SkipGroupNorm); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, SkipLayerNormalization); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, SkipLayerNormalization); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, SkipSimplifiedLayerNormalization); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, SkipSimplifiedLayerNormalization); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, ThresholdedRelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double, ThresholdedRelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, ThresholdedRelu); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 16, float_float_float, LayerNormalization); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 16, double_double_double, LayerNormalization); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 16, MLFloat16_float_MLFloat16, LayerNormalization); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 16, float_float_MLFloat16, LayerNormalization); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 16, MLFloat16_float_float, LayerNormalization); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 16, BFloat16_float_BFloat16, LayerNormalization); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float_float_float, SimplifiedLayerNormalization); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double_double_double, SimplifiedLayerNormalization); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16_float_MLFloat16, SimplifiedLayerNormalization); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float_float_MLFloat16, SimplifiedLayerNormalization); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16_float_float, SimplifiedLayerNormalization); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, Inverse); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, MatMulNBits); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, MatMulNBits); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, Trilu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, int8_t_MLFloat16, QuantizeLinear); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, uint8_t_MLFloat16, QuantizeLinear); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, int8_t_MLFloat16, DequantizeLinear); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, uint8_t_MLFloat16, DequantizeLinear); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float_int8_t, QAttention); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_int8_t, QAttention); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, FusedConv); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, FusedConv); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, TransposeMatMul); // backward compatibility -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FusedMatMul); -// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, QOrderedMatMul); -// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, QOrderedLayerNormalization); -// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, QOrderedGelu); -// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, QuantizeWithOrder); -// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, DequantizeWithOrder); -// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, QOrderedAttention); -// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, QOrderedLongformerAttention); -// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, DecoderMaskedSelfAttention); -// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, DecoderMaskedSelfAttention); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, DecoderMaskedMultiHeadAttention); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, DecoderMaskedMultiHeadAttention); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, GemmFastGelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, GemmFastGelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, GemmFastGelu); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, GemmFloat8); - -#ifdef ENABLE_ATEN -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kPytorchAtenDomain, 1, ATen); -#endif - -#ifdef ENABLE_TRAINING_OPS -// Should remove the shrunken_gather include from ENABLE_TRAINING_OPS once 1). compute optimizer is enabled for inference or -// 2). this is needed by inference for other purpose. -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, ShrunkenGather); -#endif - -#ifdef ORT_USE_NCCL -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, AllReduce); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, AllGather); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, AllToAll); -#endif - -template <> -KernelCreateInfo BuildKernelCreateInfo() { - KernelCreateInfo info; - return info; -} - -// clang-format off -Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) { - static const BuildKernelCreateInfoFn function_table[] = { - BuildKernelCreateInfo, // default entry to avoid the list become empty after ops-reducing - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, // backward compatibility - BuildKernelCreateInfo, // backward compatibility - BuildKernelCreateInfo, // backward compatibility - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - // These ops were experimental ops in onnx domain which have been removed now. We add them here as - // contrib ops to maintain backward compatibility - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - // TransposedMatMul is still here for backward compatibility - BuildKernelCreateInfo, // backward compatibility - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - -#ifdef ENABLE_ATEN - BuildKernelCreateInfo, -#endif - -#ifdef ENABLE_TRAINING_OPS - // Should remove the shrunken_gather include from ENABLE_TRAINING_OPS once 1). compute optimizer is enabled for inference or - // 2). this is needed by inference for other purpose. - BuildKernelCreateInfo, -#endif - -#ifdef ORT_USE_NCCL - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, -#endif - - }; - - for (auto& function_table_entry : function_table) { - KernelCreateInfo info = function_table_entry(); - if (info.kernel_def != nullptr) { // filter disabled entries where type is void - ORT_RETURN_IF_ERROR(kernel_registry.Register(std::move(info))); - } - } - - return Status::OK(); -} -// clang-format on - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.h b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.h deleted file mode 100644 index db9a5d4fcd83e..0000000000000 --- a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.h +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -Status RegisterRocmContribKernels(KernelRegistry& kernel_registry); - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/tools/ci_build/github/linux/build_rocm_c_api_package.sh b/tools/ci_build/github/linux/build_rocm_c_api_package.sh deleted file mode 100755 index 3ea90c73342a5..0000000000000 --- a/tools/ci_build/github/linux/build_rocm_c_api_package.sh +++ /dev/null @@ -1,40 +0,0 @@ -#!/bin/bash - -set -e -u -x - -usage() { echo "Usage: $0 -S -B -V [-H ] " 1>&2; exit 1; } - -ROCM_HOME=/opt/rocm - -while getopts S:B:V:H:I:P: parameter_Option; do - case "${parameter_Option}" in - S) SOURCE_DIR=${OPTARG};; - B) BINARY_DIR=${OPTARG};; - V) ROCM_VERSION=${OPTARG};; - H) ROCM_HOME=${OPTARG};; - I) IMAGE=${OPTARG};; - P) PYTHON_BIN=${OPTARG};; - *) usage ;; - esac -done - -EXIT_CODE=1 - -docker run -e SYSTEM_COLLECTIONURI --rm \ - --security-opt seccomp=unconfined \ - --shm-size=1024m \ - --user $UID:$(id -g $USER) \ - -e NIGHTLY_BUILD \ - --volume $SOURCE_DIR:/onnxruntime_src \ - --volume $BINARY_DIR:/build \ - --volume /data/models:/build/models:ro \ - --volume /data/onnx:/data/onnx:ro \ - --workdir /onnxruntime_src \ - $IMAGE \ - /bin/bash -c "${PYTHON_BIN:-python} /onnxruntime_src/tools/ci_build/build.py --config Release --build_dir /build --parallel --use_rocm --use_binskim_compliant_compile_flags --rocm_version=$ROCM_VERSION --rocm_home $ROCM_HOME --nccl_home $ROCM_HOME --build_shared_lib --skip_submodule_sync --skip_tests --cmake_extra_defines FETCHCONTENT_TRY_FIND_PACKAGE_MODE=NEVER && cd /build/Release && make install DESTDIR=/build/installed" - - -EXIT_CODE=$? - -set -e -exit $EXIT_CODE diff --git a/tools/ci_build/github/linux/docker/scripts/setup_rocm_yum_repo.sh b/tools/ci_build/github/linux/docker/scripts/setup_rocm_yum_repo.sh deleted file mode 100755 index 0be64d96f3a34..0000000000000 --- a/tools/ci_build/github/linux/docker/scripts/setup_rocm_yum_repo.sh +++ /dev/null @@ -1,43 +0,0 @@ -#!/bin/bash -set -e -x - -# version -ROCM_VERSION=6.2.3 - -while getopts "r:" parameter_Option -do case "${parameter_Option}" -in -r) ROCM_VERSION=${OPTARG};; -esac -done - -tee /etc/yum.repos.d/amdgpu.repo <