[ET-VK][qconv] Enable im2col to handle grouped convolution#17793
[ET-VK][qconv] Enable im2col to handle grouped convolution#17793meta-codesync[bot] merged 1 commit intogh/SS-JIA/454/basefrom
Conversation
Previously, the im2col + pointwise GEMM path (`q8ta_conv2d_im2col`) only supported non-grouped convolutions (groups=1). This diff extends it to handle grouped convolutions as well, providing significant speedups on Mali GPUs. The key changes are: **PW GEMM shader (`q8ta_conv2d_pw.glsl`)**: Added `K4_per_group` and `OC4_per_group` as push constants. The shader now computes a group index from the output channel block (`group_idx = oc_block_idx / OC4_per_group`) and offsets the im2col input read by `group_idx * K4_per_group`. For non-grouped cases (groups=1), `group_idx` is always 0, so behavior is unchanged. **PW node (`Q8taConv2dPW.cpp`)**: `add_q8ta_conv2d_pw_node` now accepts a `groups` parameter (default=1) and computes `K4_per_group` and `OC4_per_group` internally from the input/output tensor dimensions. `K4_per_group` and `OC4_per_group` were previously specialization constants; they are now push constants to avoid shader variant explosion when groups varies. **Im2col node (`Q8taConv2dIm2Col.cpp`)**: Removed the `groups == 1` assertion from `add_q8ta_im2col_node`. The im2col shader already handles groups correctly (each group's K range is contiguous in the output buffer). The `q8ta_conv2d_im2col` operator now passes the groups value through to the PW node. **Dispatch heuristic (`Q8taConv2d.cpp`)**: Updated `q8ta_conv2d` with device-aware dispatch. On Mali, im2col is used for all eligible cases (grouped and ungrouped) since it provides 1.2-3.6x speedups. On Adreno, im2col is only used for ungrouped convolutions (groups=1) where in_channels_per_group >= 32 or spatial_out <= 4096, since grouped convolutions show 0.7-0.95x regression with im2col. The heuristic uses `graph.device_is_mali()` to select the path. **Tests (`test_q8ta_conv2d.cpp`)**: Updated im2col test eligibility from `groups == 1 && channels.in % 4 == 0` to `in_channels_per_group % 4 == 0`, enabling im2col testing for grouped cases. Added SceneX v9 256x256 grouped convolution configs. Differential Revision: [D94949480](https://our.internmc.facebook.com/intern/diff/D94949480/) [ghstack-poisoned]
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/17793
Note: Links to docs will display an error until the docs builds have been completed. ❌ 12 New Failures, 2 Unrelated FailuresAs of commit 0979460 with merge base ae41854 ( NEW FAILURES - The following jobs have failed:
BROKEN TRUNK - The following jobs failed but were present on the merge base:👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
This PR needs a
|
e754d74
into
gh/SS-JIA/454/base
Previously, the im2col + pointwise GEMM path (`q8ta_conv2d_im2col`) only supported non-grouped convolutions (groups=1). This diff extends it to handle grouped convolutions as well, providing significant speedups on Mali GPUs. The key changes are: **PW GEMM shader (`q8ta_conv2d_pw.glsl`)**: Added `K4_per_group` and `OC4_per_group` as push constants. The shader now computes a group index from the output channel block (`group_idx = oc_block_idx / OC4_per_group`) and offsets the im2col input read by `group_idx * K4_per_group`. For non-grouped cases (groups=1), `group_idx` is always 0, so behavior is unchanged. **PW node (`Q8taConv2dPW.cpp`)**: `add_q8ta_conv2d_pw_node` now accepts a `groups` parameter (default=1) and computes `K4_per_group` and `OC4_per_group` internally from the input/output tensor dimensions. `K4_per_group` and `OC4_per_group` were previously specialization constants; they are now push constants to avoid shader variant explosion when groups varies. **Im2col node (`Q8taConv2dIm2Col.cpp`)**: Removed the `groups == 1` assertion from `add_q8ta_im2col_node`. The im2col shader already handles groups correctly (each group's K range is contiguous in the output buffer). The `q8ta_conv2d_im2col` operator now passes the groups value through to the PW node. **Dispatch heuristic (`Q8taConv2d.cpp`)**: Updated `q8ta_conv2d` with device-aware dispatch. On Mali, im2col is used for all eligible cases (grouped and ungrouped) since it provides 1.2-3.6x speedups. On Adreno, im2col is only used for ungrouped convolutions (groups=1) where in_channels_per_group >= 32 or spatial_out <= 4096, since grouped convolutions show 0.7-0.95x regression with im2col. The heuristic uses `graph.device_is_mali()` to select the path. **Tests (`test_q8ta_conv2d.cpp`)**: Updated im2col test eligibility from `groups == 1 && channels.in % 4 == 0` to `in_channels_per_group % 4 == 0`, enabling im2col testing for grouped cases. Added SceneX v9 256x256 grouped convolution configs. Differential Revision: [D94949480](https://our.internmc.facebook.com/intern/diff/D94949480/) ghstack-source-id: 346525921 Pull Request resolved: #17793
Previously, the im2col + pointwise GEMM path (`q8ta_conv2d_im2col`) only supported non-grouped convolutions (groups=1). This diff extends it to handle grouped convolutions as well, providing significant speedups on Mali GPUs. The key changes are: **PW GEMM shader (`q8ta_conv2d_pw.glsl`)**: Added `K4_per_group` and `OC4_per_group` as push constants. The shader now computes a group index from the output channel block (`group_idx = oc_block_idx / OC4_per_group`) and offsets the im2col input read by `group_idx * K4_per_group`. For non-grouped cases (groups=1), `group_idx` is always 0, so behavior is unchanged. **PW node (`Q8taConv2dPW.cpp`)**: `add_q8ta_conv2d_pw_node` now accepts a `groups` parameter (default=1) and computes `K4_per_group` and `OC4_per_group` internally from the input/output tensor dimensions. `K4_per_group` and `OC4_per_group` were previously specialization constants; they are now push constants to avoid shader variant explosion when groups varies. **Im2col node (`Q8taConv2dIm2Col.cpp`)**: Removed the `groups == 1` assertion from `add_q8ta_im2col_node`. The im2col shader already handles groups correctly (each group's K range is contiguous in the output buffer). The `q8ta_conv2d_im2col` operator now passes the groups value through to the PW node. **Dispatch heuristic (`Q8taConv2d.cpp`)**: Updated `q8ta_conv2d` with device-aware dispatch. On Mali, im2col is used for all eligible cases (grouped and ungrouped) since it provides 1.2-3.6x speedups. On Adreno, im2col is only used for ungrouped convolutions (groups=1) where in_channels_per_group >= 32 or spatial_out <= 4096, since grouped convolutions show 0.7-0.95x regression with im2col. The heuristic uses `graph.device_is_mali()` to select the path. **Tests (`test_q8ta_conv2d.cpp`)**: Updated im2col test eligibility from `groups == 1 && channels.in % 4 == 0` to `in_channels_per_group % 4 == 0`, enabling im2col testing for grouped cases. Added SceneX v9 256x256 grouped convolution configs. Differential Revision: [D94949480](https://our.internmc.facebook.com/intern/diff/D94949480/) ghstack-source-id: 346525921 Pull Request resolved: #17793
Stack from ghstack (oldest at bottom):
Previously, the im2col + pointwise GEMM path (
q8ta_conv2d_im2col) onlysupported non-grouped convolutions (groups=1). This diff extends it to handle
grouped convolutions as well, providing significant speedups on Mali GPUs.
The key changes are:
PW GEMM shader (
q8ta_conv2d_pw.glsl): AddedK4_per_groupandOC4_per_groupas push constants. The shader now computes a group index fromthe output channel block (
group_idx = oc_block_idx / OC4_per_group) andoffsets the im2col input read by
group_idx * K4_per_group. For non-groupedcases (groups=1),
group_idxis always 0, so behavior is unchanged.PW node (
Q8taConv2dPW.cpp):add_q8ta_conv2d_pw_nodenow accepts agroupsparameter (default=1) and computesK4_per_groupandOC4_per_groupinternally from the input/output tensor dimensions.
K4_per_groupandOC4_per_groupwere previously specialization constants; they are now pushconstants to avoid shader variant explosion when groups varies.
Im2col node (
Q8taConv2dIm2Col.cpp): Removed thegroups == 1assertionfrom
add_q8ta_im2col_node. The im2col shader already handles groups correctly(each group's K range is contiguous in the output buffer). The
q8ta_conv2d_im2coloperator now passes the groups value through to the PW node.
Dispatch heuristic (
Q8taConv2d.cpp): Updatedq8ta_conv2dwithdevice-aware dispatch. On Mali, im2col is used for all eligible cases (grouped
and ungrouped) since it provides 1.2-3.6x speedups. On Adreno, im2col is only
used for ungrouped convolutions (groups=1) where in_channels_per_group >= 32 or
spatial_out <= 4096, since grouped convolutions show 0.7-0.95x regression with
im2col. The heuristic uses
graph.device_is_mali()to select the path.Tests (
test_q8ta_conv2d.cpp): Updated im2col test eligibility fromgroups == 1 && channels.in % 4 == 0toin_channels_per_group % 4 == 0,enabling im2col testing for grouped cases. Added SceneX v9 256x256 grouped
convolution configs.
Differential Revision: D94949480