From 71970cd05dca8fd9f879d982e2c51f76d985d1c9 Mon Sep 17 00:00:00 2001 From: ssjia Date: Mon, 2 Mar 2026 13:03:34 -0800 Subject: [PATCH 1/4] [ET-VK][ez] Use tree reduction in q8ta_linear_gemv shader Replace the serial O(WGS) reduction loop with a tree reduction pattern (O(log2(WGS))). Previously, only thread 0 summed all 64 partial accumulators sequentially. Now all threads participate in a classic halving reduction, matching the pattern already used in linear_q4gsw_coop.glsl. Authored by Claude. Differential Revision: [D94949137](https://our.internmc.facebook.com/intern/diff/D94949137/) ghstack-source-id: 346524552 Pull Request resolved: https://github.com/pytorch/executorch/pull/17792 --- .../graph/ops/glsl/q8ta_linear_gemv.glsl | 29 ++++++++++++++----- 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_linear_gemv.glsl b/backends/vulkan/runtime/graph/ops/glsl/q8ta_linear_gemv.glsl index 241fc1845bf..becff2ab9ab 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/q8ta_linear_gemv.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_linear_gemv.glsl @@ -66,7 +66,12 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; #include "linear_int_weight_sums_load.glslh" #include "linear_fp_bias_load.glslh" -shared Int32Accum partial_accums[WGS]; +// Array-of-arrays shared memory layout: partial_accums[lid][tile_n4]. +// Each element is exactly 16 bytes (ivec4). This avoids the Samsung S25 +// (Adreno 830) driver bug triggered by the original Int32Accum struct layout, +// where barrier() only invalidated the first 16-byte component of each +// 32-byte struct slot, leaving subsequent components stale. 
+shared ivec4 partial_accums[WGS][TILE_N4]; void main() { const int lid = int(gl_LocalInvocationID.z); @@ -104,19 +109,29 @@ void main() { } } - partial_accums[lid] = out_accum; + [[unroll]] for (int tile_n4 = 0; tile_n4 < TILE_N4; ++tile_n4) { + partial_accums[lid][tile_n4] = out_accum.data[0][tile_n4]; + } memoryBarrierShared(); barrier(); - // Only the first thread writes the result - if (lid == 0) { - for (int i = 1; i < WGS; ++i) { + // Tree reduction: O(log2(WGS)). + for (int i = WGS / 2; i > 0; i /= 2) { + if (lid < i) { [[unroll]] for (int tile_n4 = 0; tile_n4 < TILE_N4; ++tile_n4) { - out_accum.data[0][tile_n4] += - partial_accums[i].data[0][tile_n4]; + partial_accums[lid][tile_n4] += partial_accums[lid + i][tile_n4]; } } + memoryBarrierShared(); + barrier(); + } + + // Only the first thread writes the result + if (lid == 0) { + [[unroll]] for (int tile_n4 = 0; tile_n4 < TILE_N4; ++tile_n4) { + out_accum.data[0][tile_n4] = partial_accums[0][tile_n4]; + } FPPerOutChannelParams weight_scales_tile; load_weight_scales_tile(weight_scales_tile, n4); From f60943c6884c4ae7cc151cc2ba7ae5c6545763b6 Mon Sep 17 00:00:00 2001 From: ssjia Date: Mon, 2 Mar 2026 13:03:39 -0800 Subject: [PATCH 2/4] [ET-VK][qconv] Enable im2col to handle grouped convolution Previously, the im2col + pointwise GEMM path (`q8ta_conv2d_im2col`) only supported non-grouped convolutions (groups=1). This diff extends it to handle grouped convolutions as well, providing significant speedups on Mali GPUs. The key changes are: **PW GEMM shader (`q8ta_conv2d_pw.glsl`)**: Added `K4_per_group` and `OC4_per_group` as push constants. The shader now computes a group index from the output channel block (`group_idx = oc_block_idx / OC4_per_group`) and offsets the im2col input read by `group_idx * K4_per_group`. For non-grouped cases (groups=1), `group_idx` is always 0, so behavior is unchanged. 
**PW node (`Q8taConv2dPW.cpp`)**: `add_q8ta_conv2d_pw_node` now accepts a `groups` parameter (default=1) and computes `K4_per_group` and `OC4_per_group` internally from the input/output tensor dimensions. `K4_per_group` was previously a specialization constant (`conv2d_params_K4_per_group`); it and the newly added `OC4_per_group` are now push constants to avoid shader variant explosion when `groups` varies. **Im2col node (`Q8taConv2dIm2Col.cpp`)**: Removed the `groups == 1` assertion from `add_q8ta_im2col_node`. The im2col shader already handles groups correctly (each group's K range is contiguous in the output buffer). The `q8ta_conv2d_im2col` operator now passes the groups value through to the PW node. **Dispatch heuristic (`Q8taConv2d.cpp`)**: Updated `q8ta_conv2d` with device-aware dispatch. On Mali, im2col is used for all eligible cases (grouped and ungrouped, `in_channels_per_group % 4 == 0`) since it provides 1.2-3.6x speedups. On Adreno and all other non-Mali devices, im2col is only used for eligible ungrouped convolutions (groups=1) where in_channels_per_group >= 32 or spatial_out <= 4096, since grouped convolutions show 0.7-0.95x regression with im2col on Adreno. The heuristic uses `graph.device_is_mali()` to select the path. **Tests (`test_q8ta_conv2d.cpp`)**: Updated im2col test eligibility from `groups == 1 && channels.in % 4 == 0` to `in_channels_per_group % 4 == 0`, enabling im2col testing for grouped cases. Added SceneX v9 256x256 grouped convolution configs. 
Differential Revision: [D94949480](https://our.internmc.facebook.com/intern/diff/D94949480/) ghstack-source-id: 346525921 Pull Request resolved: https://github.com/pytorch/executorch/pull/17793 --- backends/vulkan/runtime/graph/ComputeGraph.h | 5 +++ .../graph/ops/glsl/q8ta_conv2d_pw.glsl | 17 ++++++--- .../runtime/graph/ops/impl/Q8taConv2d.cpp | 21 +++++++--- .../runtime/graph/ops/impl/Q8taConv2d.h | 22 ++++++++++- .../graph/ops/impl/Q8taConv2dIm2Col.cpp | 10 +++-- .../runtime/graph/ops/impl/Q8taConv2dPW.cpp | 38 ++++++++----------- backends/vulkan/runtime/vk_api/Device.cpp | 4 +- .../test/custom_ops/test_q8ta_conv2d.cpp | 29 +++++++++++--- 8 files changed, 100 insertions(+), 46 deletions(-) diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h index a7c8cffffd1..5ce84dd705b 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.h +++ b/backends/vulkan/runtime/graph/ComputeGraph.h @@ -661,6 +661,11 @@ class ComputeGraph final { inline bool device_is_adreno() { return context_->adapter_ptr()->device_type() == vkapi::DeviceType::ADRENO; } + + inline bool device_is_mali() { + return context_->adapter_ptr()->device_type() == vkapi::DeviceType::MALI; + } + const std::string& device_name() { return context()->adapter_ptr()->device_name(); } diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_pw.glsl b/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_pw.glsl index d408b7ca9b8..fc063579c45 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_pw.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_pw.glsl @@ -55,13 +55,14 @@ layout(push_constant) uniform restrict Block { int input_zp; float output_inv_scale; int output_zp; + int K4_per_group; + int OC4_per_group; }; layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; ${layout_declare_spec_const(C, "int", "apply_bias", "1")} ${layout_declare_spec_const(C, "int", "activation_type", "0")} -${layout_declare_spec_const(C, 
"int", "conv2d_params_K4_per_group", "1")} // Layout specialization constants ${layout_declare_spec_const(C, "int", "outp_layout", "CONTIG_LAYOUT_INT")} @@ -124,12 +125,18 @@ void main() { } } - // Compute initial input tile index - // Input has same spatial layout, channel dimension iterates from 0 - int input_idx = oh * inp_h_stride + ow_block_idx * inp_w_stride; + // Compute group index from output channel block + const int group_idx = oc_block_idx / OC4_per_group; + + // Compute initial input tile index with group offset + // For grouped im2col, each group's K range starts at group_idx * K4_per_group + // For non-grouped (groups=1), group_idx is always 0 so offset is 0 + int input_idx = oh * inp_h_stride + + ow_block_idx * inp_w_stride + + group_idx * K4_per_group; // Main accumulation loop over K dimension - for (int k4 = 0; k4 < conv2d_params_K4_per_group; k4++) { + for (int k4 = 0; k4 < K4_per_group; k4++) { // Load packed int8 input tile (TILE_M4=1, TILE_K4=1) // Each int contains 4 packed int8s (one per width position in the tile) ivec4 int8_input_tile = t_packed_int8_input[input_idx]; diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.cpp b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.cpp index d1a4840fbba..f6e89bef03d 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.cpp @@ -430,12 +430,21 @@ void q8ta_conv2d(ComputeGraph& graph, const std::vector& args) { const int64_t W_out = graph.size_at(-1, output); const int64_t spatial_out = H_out * W_out; - // Use im2col when the channel depth is sufficient for tiled GEMM to win, or - // when the output spatial area is small enough that the im2col buffer stays - // manageable. For large spatial outputs with few channels, the im2col buffer - // becomes too large and the general shader is more efficient. 
- const bool use_im2col = groups == 1 && in_channels_per_group % 4 == 0 && - (in_channels_per_group >= 64 || spatial_out <= 4096); + // Im2col requires input channels per group to be a multiple of 4 + const bool im2col_eligible = in_channels_per_group % 4 == 0; + + bool use_im2col = false; + if (graph.device_is_mali()) { + // On Mali, im2col is faster than the general shader across the board. + use_im2col = im2col_eligible; + } else { + // Default: on Adreno and unknown GPU architectures, im2col is only + // beneficial for ungrouped convolutions with sufficient channel depth or + // small spatial output. For grouped convolutions, the general shader is + // more efficient (0.7-0.95x regression measured on Adreno). + use_im2col = im2col_eligible && groups == 1 && + (in_channels_per_group >= 32 || spatial_out <= 4096); + } if (use_im2col) { q8ta_conv2d_im2col(graph, args); diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.h b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.h index 2779a7445a8..6da98fbef74 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.h +++ b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.h @@ -121,7 +121,27 @@ void add_q8ta_conv2d_pw_node( const ValueRef bias_data, const ValueRef packed_bias, const uint32_t activation_type, - const ValueRef packed_int8_output); + const ValueRef packed_int8_output, + const int32_t groups = 1); + +std::vector calculate_q8ta_im2col_sizes( + ComputeGraph* graph, + const ValueRef& input, + const ValueRef& output, + const ValueRef& kernel_size, + const ValueRef& groups); + +void add_q8ta_im2col_node( + ComputeGraph& graph, + const ValueRef packed_int8_input, + const ValueRef kernel_size, + const ValueRef stride, + const ValueRef padding, + const ValueRef dilation, + const ValueRef groups, + const ValueRef packed_int8_output, + const ValueRef packed_int8_im2col, + const int32_t zp); void q8ta_conv2d_im2col(ComputeGraph& graph, const std::vector& args); diff --git 
a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dIm2Col.cpp b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dIm2Col.cpp index 161b5e8fc24..b43fe9eacc6 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dIm2Col.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dIm2Col.cpp @@ -127,9 +127,8 @@ void add_q8ta_im2col_node( dilation, groups); - // At the moment, the im2col path only supports non-grouped convolutions - VK_CHECK_COND(conv_params.groups == 1); - // The implementation also requires that input channels is a multiple of 4 + // The implementation requires that input channels per group is a multiple of + // 4 VK_CHECK_COND(conv_params.in_channels_per_group % 4 == 0); std::string kernel_name = "q8ta_im2col"; @@ -257,6 +256,8 @@ void q8ta_conv2d_im2col( zp); // Step 2: Perform pointwise convolution on the im2col result + const int32_t groups_val = graph.extract_scalar(groups); + add_q8ta_conv2d_pw_node( graph, packed_int8_im2col, @@ -270,7 +271,8 @@ void q8ta_conv2d_im2col( bias_data, packed_bias, activation_type_val, - packed_int8_output); + packed_int8_output, + groups_val); } REGISTER_OPERATORS { diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dPW.cpp b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dPW.cpp index 1872e8796de..e27e0699dac 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dPW.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dPW.cpp @@ -200,39 +200,40 @@ void add_q8ta_conv2d_pw_node( const ValueRef bias_data, const ValueRef packed_bias, const uint32_t activation_type, - const ValueRef packed_int8_output) { - // Validate packed dim info for input and output tensors - // To maximize performance, the input tensor must be in 4W4C layout + const ValueRef packed_int8_output, + const int32_t groups) { VK_CHECK_COND(q8ta_conv2d_check_4w4c_packed_dim_info( graph.packed_dim_info_of(packed_int8_input))); - // However, the requirements for output tensor layout is flexible 
VK_CHECK_COND(q8ta_conv2d_check_packed_dim_info( graph.packed_dim_info_of(packed_int8_output))); - // Validate dtype is kInt8x4 VK_CHECK_COND(graph.dtype_of(packed_int8_input) == vkapi::kInt8x4); VK_CHECK_COND(graph.dtype_of(packed_int8_output) == vkapi::kInt8x4); + // Compute K4_per_group and OC4_per_group from tensor dimensions and groups + // Input K dim (dim -3) = K_per_group * groups for grouped im2col, or IC for + // non-grouped. Either way, K4_per_group = div_up_4(K_dim / groups). + const int32_t K_dim = graph.size_at(-3, packed_int8_input); + const int32_t OC = graph.size_at(-3, packed_int8_output); + const int32_t K4_per_group = + static_cast(utils::div_up_4(K_dim / groups)); + const int32_t OC4_per_group = + static_cast(utils::div_up_4(OC / groups)); + float input_scale_val = graph.extract_scalar(input_scale); int32_t input_zp_val = graph.extract_scalar(input_zp); float output_inv_scale_val = 1.0f / graph.extract_scalar(output_scale); int32_t output_zp_val = graph.extract_scalar(output_zp); - uint32_t apply_bias = 1; - if (graph.val_is_none(bias_data)) { - apply_bias = 0; - } - - // Get input channel count for K4_per_group - const uint32_t IC = graph.size_at(-3, packed_int8_input); - const uint32_t K4_per_group = utils::div_up_4(IC); - + uint32_t apply_bias = graph.val_is_none(bias_data) ? 0u : 1u; std::vector push_constants = { PushConstantDataInfo(&input_scale_val, sizeof(input_scale_val)), PushConstantDataInfo(&input_zp_val, sizeof(input_zp_val)), PushConstantDataInfo(&output_inv_scale_val, sizeof(output_inv_scale_val)), PushConstantDataInfo(&output_zp_val, sizeof(output_zp_val)), + PushConstantDataInfo(&K4_per_group, sizeof(K4_per_group)), + PushConstantDataInfo(&OC4_per_group, sizeof(OC4_per_group)), }; const bool use_hw_dot = @@ -241,17 +242,13 @@ void add_q8ta_conv2d_pw_node( use_hw_dot ? 
"q8ta_conv2d_pw" : "q8ta_conv2d_pw_fallback"; add_dtype_suffix(kernel_name, graph.dtype_of(packed_weight_scales)); - // Pass metadata for both output and input tensors vkapi::ParamsBindList param_buffers = { graph.buffer_meta_ubo(packed_int8_output), graph.buffer_meta_ubo(packed_int8_input)}; - // Build spec constants: apply_bias, activation_type + layout constants vkapi::SpecVarList spec_constants = { apply_bias, activation_type, - K4_per_group, - // Layout specialization constants graph.hashed_layout_of(packed_int8_output), graph.hashed_layout_of(packed_int8_input), }; @@ -261,7 +258,6 @@ void add_q8ta_conv2d_pw_node( VK_KERNEL_FROM_STR(kernel_name), pick_q8ta_conv2d_pw_global_wg_size, pick_q8ta_conv2d_pw_local_wg_size, - // Inputs and Outputs {{packed_int8_output, vkapi::kWrite}, {{packed_int8_input, packed_weight, @@ -269,13 +265,9 @@ void add_q8ta_conv2d_pw_node( packed_weight_scales, packed_bias}, vkapi::kRead}}, - // Shader params buffers param_buffers, - // Push Constants push_constants, - // Specialization Constants spec_constants, - // Resize args {})); } diff --git a/backends/vulkan/runtime/vk_api/Device.cpp b/backends/vulkan/runtime/vk_api/Device.cpp index ef3d57be90d..dbe8a73651c 100644 --- a/backends/vulkan/runtime/vk_api/Device.cpp +++ b/backends/vulkan/runtime/vk_api/Device.cpp @@ -132,7 +132,9 @@ PhysicalDevice::PhysicalDevice( device_type = DeviceType::SWIFTSHADER; } else if (device_name.find("nvidia") != std::string::npos) { device_type = DeviceType::NVIDIA; - } else if (device_name.find("mali") != std::string::npos) { + } else if ( + device_name.find("mali") != std::string::npos || + device_name.find("immortalis") != std::string::npos) { device_type = DeviceType::MALI; } } diff --git a/backends/vulkan/test/custom_ops/test_q8ta_conv2d.cpp b/backends/vulkan/test/custom_ops/test_q8ta_conv2d.cpp index 41ddd389aa8..9f0273a5b83 100644 --- a/backends/vulkan/test/custom_ops/test_q8ta_conv2d.cpp +++ b/backends/vulkan/test/custom_ops/test_q8ta_conv2d.cpp 
@@ -237,9 +237,9 @@ std::vector generate_quantized_conv2d_easy_cases() { test_cases.push_back(create_test_case_from_config( config, vkapi::kFloat, fp_storage_type, int8_memory_layout)); - // Test im2col implementation for non-grouped convolutions with input - // channels that are a multiple of 4 and stride_w == 1 - if (config.groups == 1 && config.channels.in % 4 == 0) { + // Test im2col implementation when input channels per group is a + // multiple of 4 + if ((config.channels.in / config.groups) % 4 == 0) { test_cases.push_back(create_test_case_from_config( config, vkapi::kFloat, @@ -379,6 +379,21 @@ static std::vector generate_quantized_conv2d_test_cases() { Padding(2, 2), Dilation(1, 1), 4}, + // SceneX v9 grouped convolutions (large spatial) + {OutInChannels(128, 128), + InputSize2D(256, 256), + KernelSize(5, 5), + Stride(2, 2), + Padding(2, 2), + Dilation(1, 1), + 4}, + {OutInChannels(64, 64), + InputSize2D(256, 256), + KernelSize(3, 3), + Stride(1, 1), + Padding(1, 1), + Dilation(1, 1), + 2}, // Deep channels + small spatial (ResNet50 stage 5 bottleneck) {OutInChannels(512, 512), InputSize2D(7, 7), @@ -426,9 +441,11 @@ static std::vector generate_quantized_conv2d_test_cases() { int8_memory_layout, /*impl_selector=*/"general")); - // Test im2col implementation for non-grouped convolutions with input - // channels that are a multiple of 4 and stride_w == 1 - if (config.groups == 1 && config.channels.in % 4 == 0) { + // Test im2col implementation when input channels per group is a + // multiple of 4 + const int64_t in_channels_per_group = + config.channels.in / config.groups; + if (in_channels_per_group % 4 == 0) { test_cases.push_back(create_test_case_from_config( config, vkapi::kFloat, From b459a80a3be7b264636ecd1de79e07e7c37274ec Mon Sep 17 00:00:00 2001 From: ssjia Date: Mon, 2 Mar 2026 13:03:44 -0800 Subject: [PATCH 3/4] [ET-VK][qconv] Add dynamic PACKED_INT8_CONV2D memory layout for device-adaptive conv2d Performance testing of quantized int8 convolutions 
reveals that different algorithms perform better on different GPU architectures: im2col is faster on Mali while direct convolution is faster on Adreno. The optimal memory layout differs per algorithm (4C for im2col, 4C1W for direct convolution). This introduces a new "dynamic" memory layout PACKED_INT8_CONV2D that is serialized at export time and resolved to a concrete layout at runtime based on the device's GPU architecture. The resolution logic in ResolveLayouts.cpp mirrors the im2col vs direct convolution decision in Q8taConv2d.cpp. Differential Revision: [D94949134](https://our.internmc.facebook.com/intern/diff/D94949134/) ghstack-source-id: 346525918 Pull Request resolved: https://github.com/pytorch/executorch/pull/17794 --- backends/vulkan/op_registry.py | 2 +- backends/vulkan/runtime/ResolveLayouts.cpp | 206 ++++++++++++++++++ backends/vulkan/runtime/ResolveLayouts.h | 26 +++ backends/vulkan/runtime/VulkanBackend.cpp | 37 +++- backends/vulkan/serialization/schema.fbs | 2 + .../serialization/vulkan_graph_schema.py | 2 + backends/vulkan/utils.py | 20 +- 7 files changed, 289 insertions(+), 6 deletions(-) create mode 100644 backends/vulkan/runtime/ResolveLayouts.cpp create mode 100644 backends/vulkan/runtime/ResolveLayouts.h diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index b18bf3b81c6..62997ea956f 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -839,7 +839,7 @@ def register_q8ta_conv_pw_op(): def register_q8ta_conv2d_ops(): return OpFeatures( inputs_storage=[ - utils.PACKED_INT8_4C1W_BUFFER, # input + utils.PACKED_INT8_CONV2D_BUFFER, # input utils.NO_STORAGE, # input_scale (non tensor) utils.NO_STORAGE, # input_zero_point (non tensor) utils.NO_STORAGE, # weight (prepacked) diff --git a/backends/vulkan/runtime/ResolveLayouts.cpp b/backends/vulkan/runtime/ResolveLayouts.cpp new file mode 100644 index 00000000000..7f3b5934123 --- /dev/null +++ b/backends/vulkan/runtime/ResolveLayouts.cpp @@ -0,0 
+1,206 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#include + +namespace vkcompute { + +namespace { + +using VkGraphPtr = const vkgraph::VkGraph*; +using OpCallPtr = const vkgraph::OperatorCall*; +using VkValuePtr = const vkgraph::VkValue*; +using VkTensorPtr = const vkgraph::VkTensor*; +using UIntVector = const flatbuffers::Vector*; + +bool is_dynamic_layout(const vkgraph::VkMemoryLayout layout) { + return layout == vkgraph::VkMemoryLayout::PACKED_INT8_CONV2D; +} + +bool is_packed_int8_layout(vkgraph::VkMemoryLayout layout) { + switch (layout) { + case vkgraph::VkMemoryLayout::PACKED_INT8_4W4C: + case vkgraph::VkMemoryLayout::PACKED_INT8_4H4W: + case vkgraph::VkMemoryLayout::PACKED_INT8_4W: + case vkgraph::VkMemoryLayout::PACKED_INT8_4C: + case vkgraph::VkMemoryLayout::PACKED_INT8_4C1W: + return true; + default: + return false; + } +} + +vkgraph::VkMemoryLayout get_resolved_layout( + uint32_t fb_id, + VkGraphPtr flatbuffer, + const std::unordered_map& + memory_layout_overrides) { + auto it = memory_layout_overrides.find(fb_id); + if (it != memory_layout_overrides.end()) { + return it->second; + } + VkValuePtr value = flatbuffer->values()->Get(fb_id); + if (value->value_type() != vkgraph::GraphTypes::VkTensor) { + return vkgraph::VkMemoryLayout::DEFAULT_LAYOUT; + } + return value->value_as_VkTensor()->memory_layout(); +} + +void resolve_dynamic_args( + VkGraphPtr flatbuffer, + OpCallPtr op_call, + std::unordered_map& + memory_layout_overrides) { + // Find the first arg tensor with a non-dynamic packed int8 layout + vkgraph::VkMemoryLayout resolved_layout = + vkgraph::VkMemoryLayout::DEFAULT_LAYOUT; + bool found = false; + for (int i = 0; i < op_call->args()->size(); ++i) { + const uint32_t fb_id = static_cast(op_call->args()->Get(i)); + VkValuePtr value = 
flatbuffer->values()->Get(fb_id); + if (value->value_type() != vkgraph::GraphTypes::VkTensor) { + continue; + } + vkgraph::VkMemoryLayout layout = + get_resolved_layout(fb_id, flatbuffer, memory_layout_overrides); + if (is_packed_int8_layout(layout)) { + resolved_layout = layout; + found = true; + break; + } + } + + if (!found) { + return; + } + + // Override all args whose resolved layout is still dynamic + for (int i = 0; i < op_call->args()->size(); ++i) { + const uint32_t fb_id = static_cast(op_call->args()->Get(i)); + vkgraph::VkMemoryLayout layout = + get_resolved_layout(fb_id, flatbuffer, memory_layout_overrides); + if (is_dynamic_layout(layout)) { + memory_layout_overrides[fb_id] = resolved_layout; + } + } +} + +void resolve_q8ta_conv2d( + VkGraphPtr flatbuffer, + OpCallPtr op_call, + ComputeGraph* compute_graph, + std::unordered_map& + memory_layout_overrides) { + // q8ta_conv2d args layout: + // 0: input, 1: input_scale, 2: input_zp, 3: weight, 4: weight_sums, + // 5: weight_scales, 6: output_scale, 7: output_zp, 8: bias, + // 9: kernel_size, 10: stride, 11: padding, 12: dilation, 13: groups, + // 14: activation, 15: output + + const uint32_t input_fb_id = static_cast(op_call->args()->Get(0)); + const uint32_t groups_fb_id = static_cast(op_call->args()->Get(13)); + const uint32_t output_fb_id = static_cast(op_call->args()->Get(15)); + + // Only resolve if the input tensor has a dynamic layout + VkTensorPtr input_tensor = + flatbuffer->values()->Get(input_fb_id)->value_as_VkTensor(); + if (!is_dynamic_layout(input_tensor->memory_layout())) { + return; + } + + // Extract groups value + VkValuePtr groups_value = flatbuffer->values()->Get(groups_fb_id); + const int64_t groups = groups_value->value_as_Int()->int_val(); + + // Extract input tensor dimensions + UIntVector input_dims = input_tensor->dims(); + const int64_t input_ndim = input_dims->size(); + const int64_t in_channels = input_dims->Get(input_ndim - 3); + const int64_t in_channels_per_group = 
in_channels / groups; + + // Extract output tensor dimensions + VkTensorPtr output_tensor = + flatbuffer->values()->Get(output_fb_id)->value_as_VkTensor(); + UIntVector output_dims = output_tensor->dims(); + const int64_t output_ndim = output_dims->size(); + const int64_t H_out = output_dims->Get(output_ndim - 2); + const int64_t W_out = output_dims->Get(output_ndim - 1); + const int64_t spatial_out = H_out * W_out; + + // Replicate the im2col decision logic from Q8taConv2d.cpp + const bool im2col_eligible = in_channels_per_group % 4 == 0; + + bool use_im2col = false; + if (compute_graph->device_is_mali()) { + use_im2col = im2col_eligible; + } else { + use_im2col = im2col_eligible && groups == 1 && + (in_channels_per_group >= 32 || spatial_out <= 4096); + } + + if (use_im2col) { + memory_layout_overrides[input_fb_id] = + vkgraph::VkMemoryLayout::PACKED_INT8_4C; + } else { + memory_layout_overrides[input_fb_id] = + vkgraph::VkMemoryLayout::PACKED_INT8_4C1W; + } +} + +void resolve_q8ta_conv2d_dw( + VkGraphPtr flatbuffer, + OpCallPtr op_call, + std::unordered_map& + memory_layout_overrides) { + const uint32_t input_fb_id = static_cast(op_call->args()->Get(0)); + + // Only override if not already overridden by a previous op + if (memory_layout_overrides.count(input_fb_id) > 0) { + return; + } + + // Only resolve if the input tensor has a dynamic layout + VkTensorPtr input_tensor = + flatbuffer->values()->Get(input_fb_id)->value_as_VkTensor(); + if (!is_dynamic_layout(input_tensor->memory_layout())) { + return; + } + + memory_layout_overrides[input_fb_id] = + vkgraph::VkMemoryLayout::PACKED_INT8_4C1W; +} + +} // namespace + +void resolve_memory_layouts( + const vkgraph::VkGraph* flatbuffer, + ComputeGraph* compute_graph, + std::unordered_map& + memory_layout_overrides) { + // First, handle ops where input memory layout is impactful for performance + for (const auto* op_call : *(flatbuffer->chain())) { + const std::string op_name = op_call->name()->str(); + + if (op_name 
== "et_vk.q8ta_conv2d.default") { + resolve_q8ta_conv2d( + flatbuffer, op_call, compute_graph, memory_layout_overrides); + } else if (op_name == "et_vk.q8ta_conv2d_dw.default") { + resolve_q8ta_conv2d_dw(flatbuffer, op_call, memory_layout_overrides); + } + } + // Then, try to ensure ops use the same memory layout whenever possible. + for (const auto* op_call : *(flatbuffer->chain())) { + resolve_dynamic_args(flatbuffer, op_call, memory_layout_overrides); + } +} + +} // namespace vkcompute diff --git a/backends/vulkan/runtime/ResolveLayouts.h b/backends/vulkan/runtime/ResolveLayouts.h new file mode 100644 index 00000000000..332d192d4af --- /dev/null +++ b/backends/vulkan/runtime/ResolveLayouts.h @@ -0,0 +1,26 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +#include + +namespace vkcompute { + +class ComputeGraph; + +void resolve_memory_layouts( + const vkgraph::VkGraph* flatbuffer, + ComputeGraph* compute_graph, + std::unordered_map& + memory_layout_overrides); + +} // namespace vkcompute diff --git a/backends/vulkan/runtime/VulkanBackend.cpp b/backends/vulkan/runtime/VulkanBackend.cpp index 7f7afffcf57..d4eeb9b1dd4 100644 --- a/backends/vulkan/runtime/VulkanBackend.cpp +++ b/backends/vulkan/runtime/VulkanBackend.cpp @@ -6,6 +6,7 @@ * LICENSE file in the root directory of this source tree. 
*/ +#include #include #include @@ -32,6 +33,7 @@ #include #include #include +#include #include namespace executorch { @@ -146,8 +148,13 @@ utils::GPUMemoryLayout get_memory_layout( return utils::kPackedInt8_4H4W; case vkgraph::VkMemoryLayout::PACKED_INT8_4W: return utils::kPackedInt8_4W; + case vkgraph::VkMemoryLayout::PACKED_INT8_4C: + return utils::kPackedInt8_4C; case vkgraph::VkMemoryLayout::PACKED_INT8_4C1W: return utils::kPackedInt8_4C1W; + case vkgraph::VkMemoryLayout::PACKED_INT8_CONV2D: + // Fallback for unresolved dynamic layout + return utils::kPackedInt8_4C1W; default: break; } @@ -205,6 +212,8 @@ class GraphBuilder { std::vector loaded_buffers_from_map_; std::vector ref_mapping_; + std::unordered_map + memory_layout_overrides_; public: explicit GraphBuilder( @@ -217,7 +226,13 @@ class GraphBuilder { constant_data_(constant_data), named_data_map_(named_data_map), loaded_buffers_from_map_(), - ref_mapping_() {} + ref_mapping_(), + memory_layout_overrides_() {} + + void resolve_layouts() { + resolve_memory_layouts( + flatbuffer_, compute_graph_, memory_layout_overrides_); + } void resize(uint32_t size) { ref_mapping_.resize(size, INT32_MAX); @@ -235,6 +250,21 @@ class GraphBuilder { return ref_mapping_[fb_id]; } + utils::GPUMemoryLayout get_resolved_memory_layout( + const uint32_t fb_id, + VkTensorPtr tensor_fb, + const std::vector& dims_vector) { + auto it = memory_layout_overrides_.find(fb_id); + if (it != memory_layout_overrides_.end()) { + return get_memory_layout(it->second); + } + + if (tensor_fb->memory_layout() == vkgraph::VkMemoryLayout::DEFAULT_LAYOUT) { + return compute_graph_->suggested_memory_layout(dims_vector); + } + return get_memory_layout(tensor_fb->memory_layout()); + } + void add_tensor_to_graph(const uint32_t fb_id, VkTensorPtr tensor_fb) { const vkapi::ScalarType& dtype = get_scalar_type(tensor_fb->datatype()); utils::StorageType storage_type = @@ -246,9 +276,7 @@ class GraphBuilder { const std::vector dims_vector(dims_fb->cbegin(), 
dims_fb->cend()); utils::GPUMemoryLayout memory_layout = - tensor_fb->memory_layout() == vkgraph::VkMemoryLayout::DEFAULT_LAYOUT - ? compute_graph_->suggested_memory_layout(dims_vector) - : get_memory_layout(tensor_fb->memory_layout()); + get_resolved_memory_layout(fb_id, tensor_fb, dims_vector); ValueRef ref; if (tensor_fb->constant_id() >= 0) { @@ -593,6 +621,7 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface { GraphBuilder builder( compute_graph, flatbuffer_graph, constant_data, named_data_map); + builder.resolve_layouts(); builder.build_graph(); compute_graph->prepare(); diff --git a/backends/vulkan/serialization/schema.fbs b/backends/vulkan/serialization/schema.fbs index 36f9feaa580..92cee15cfe8 100644 --- a/backends/vulkan/serialization/schema.fbs +++ b/backends/vulkan/serialization/schema.fbs @@ -43,7 +43,9 @@ enum VkMemoryLayout : ubyte { PACKED_INT8_4W4C = 3, PACKED_INT8_4H4W = 4, PACKED_INT8_4W = 5, + PACKED_INT8_4C = 6, PACKED_INT8_4C1W = 8, + PACKED_INT8_CONV2D = 9, DEFAULT_LAYOUT = 255, } diff --git a/backends/vulkan/serialization/vulkan_graph_schema.py b/backends/vulkan/serialization/vulkan_graph_schema.py index 845a59a4dff..c53111c2092 100644 --- a/backends/vulkan/serialization/vulkan_graph_schema.py +++ b/backends/vulkan/serialization/vulkan_graph_schema.py @@ -51,7 +51,9 @@ class VkMemoryLayout(IntEnum): PACKED_INT8_4W4C = 3 PACKED_INT8_4H4W = 4 PACKED_INT8_4W = 5 + PACKED_INT8_4C = 6 PACKED_INT8_4C1W = 8 + PACKED_INT8_CONV2D = 9 DEFAULT_LAYOUT = 255 def __str__(self) -> str: diff --git a/backends/vulkan/utils.py b/backends/vulkan/utils.py index dde9aaac973..2a3c3910c48 100644 --- a/backends/vulkan/utils.py +++ b/backends/vulkan/utils.py @@ -606,6 +606,7 @@ def node_has_target(node: Any, target: str): VkMemoryLayout.PACKED_INT8_4H4W, VkMemoryLayout.PACKED_INT8_4W, VkMemoryLayout.PACKED_INT8_4C1W, + VkMemoryLayout.PACKED_INT8_CONV2D, } universal_memory_layout_set: Set[VkMemoryLayout] = ( @@ -622,6 +623,7 @@ def 
node_has_target(node: Any, target: str): VkMemoryLayout.PACKED_INT8_4W4C: 2, VkMemoryLayout.PACKED_INT8_4H4W: 0, VkMemoryLayout.PACKED_INT8_4C1W: 2, + VkMemoryLayout.PACKED_INT8_CONV2D: 2, } @@ -686,6 +688,11 @@ def from_repr( packed_dim=2, packed_dim_block_size=4 if is_buffer else 16, ) + elif memory_layout == VkMemoryLayout.PACKED_INT8_CONV2D: + return cls( + packed_dim=2, + packed_dim_block_size=4, + ) else: raise ValueError(f"Unknown memory layout: {memory_layout}") @@ -742,6 +749,10 @@ def required_image_extents(sizes: torch.Size, layout: VkMemoryLayout) -> ImageEx elif layout == VkMemoryLayout.PACKED_INT8_4H4W: height = (height + 3) // 4 width = (width + 3) // 4 + elif layout == VkMemoryLayout.PACKED_INT8_CONV2D: + # Use conservative extents (same as 4W4C) since this is buffer-only + width = (width + 3) // 4 + channels = (channels + 3) // 4 else: raise RuntimeError(f"Unsupported memory layout {layout}") @@ -1175,8 +1186,15 @@ def filter_invalid_reprs_for_node_list( PACKED_INT8_4W_BUFFER = TensorRepSet({VkMemoryLayout.PACKED_INT8_4W}, set()) PACKED_INT8_4C1W_BUFFER = TensorRepSet({VkMemoryLayout.PACKED_INT8_4C1W}, set()) +PACKED_INT8_CONV2D_BUFFER = TensorRepSet({VkMemoryLayout.PACKED_INT8_CONV2D}, set()) + PACKED_INT8_CHANNELS_PACKED_BUFFER = TensorRepSet( - {VkMemoryLayout.PACKED_INT8_4W4C, VkMemoryLayout.PACKED_INT8_4C1W}, set() + { + VkMemoryLayout.PACKED_INT8_4W4C, + VkMemoryLayout.PACKED_INT8_4C1W, + VkMemoryLayout.PACKED_INT8_CONV2D, + }, + set(), ) From ef87d609c0163ae62e6213ed5bb0d05e5840f7a7 Mon Sep 17 00:00:00 2001 From: ssjia Date: Mon, 2 Mar 2026 13:03:48 -0800 Subject: [PATCH 4/4] [ET-VK][testing] Add GPU device name override for on-device model tests Add the ability to override the Vulkan device name at runtime so that device-adaptive code paths (e.g. memory layout selection) can be tested on hardware that doesn't match the overridden device type. 
PhysicalDevice::override_device_name() and Adapter::override_device_name() are added for VULKAN_DEBUG (debug/testing) use; note that the hunks in this patch declare and define them without a VULKAN_DEBUG preprocessor guard. The device type detection logic is refactored into a reusable determine_device_type() helper to avoid duplication between the constructor and the override function. All test binaries in fb/test/models/ (classification, greenscreen, scenex, skin_seg) now accept --gpu_name to invoke the override before loading the model. The Skycastle CI workflows are updated to re-run classification and greenscreen tests with --gpu_name Mali-G715 in addition to the default run. Differential Revision: [D94949136](https://our.internmc.facebook.com/intern/diff/D94949136/) ghstack-source-id: 346525920 Pull Request resolved: https://github.com/pytorch/executorch/pull/17795 --- backends/vulkan/runtime/vk_api/Adapter.cpp | 4 +++ backends/vulkan/runtime/vk_api/Adapter.h | 2 ++ backends/vulkan/runtime/vk_api/Device.cpp | 42 ++++++++++++++++------ backends/vulkan/runtime/vk_api/Device.h | 3 ++ 4 files changed, 40 insertions(+), 11 deletions(-) diff --git a/backends/vulkan/runtime/vk_api/Adapter.cpp b/backends/vulkan/runtime/vk_api/Adapter.cpp index 8eae5ff35e4..82482a5b7c4 100644 --- a/backends/vulkan/runtime/vk_api/Adapter.cpp +++ b/backends/vulkan/runtime/vk_api/Adapter.cpp @@ -375,6 +375,10 @@ void Adapter::submit_cmd( VK_CHECK(vkQueueSubmit(device_queue.handle, 1u, &submit_info, fence)); } +void Adapter::override_device_name(const std::string& new_name) { + physical_device_.override_device_name(new_name); +} + std::string Adapter::stringize() const { std::stringstream ss; diff --git a/backends/vulkan/runtime/vk_api/Adapter.h b/backends/vulkan/runtime/vk_api/Adapter.h index 89beb5c3a5c..3c503deab70 100644 --- a/backends/vulkan/runtime/vk_api/Adapter.h +++ b/backends/vulkan/runtime/vk_api/Adapter.h @@ -306,6 +306,8 @@ class Adapter final { VkSemaphore wait_semaphore = VK_NULL_HANDLE, VkSemaphore signal_semaphore = VK_NULL_HANDLE); + void override_device_name(const std::string& 
new_name); + std::string stringize() const; friend std::ostream& operator<<(std::ostream&, const Adapter&); }; diff --git a/backends/vulkan/runtime/vk_api/Device.cpp b/backends/vulkan/runtime/vk_api/Device.cpp index dbe8a73651c..cb6a54dc489 100644 --- a/backends/vulkan/runtime/vk_api/Device.cpp +++ b/backends/vulkan/runtime/vk_api/Device.cpp @@ -21,6 +21,25 @@ namespace vkcompute { namespace vkapi { +namespace { + +DeviceType determine_device_type(const std::string& device_name) { + if (device_name.find("adreno") != std::string::npos) { + return DeviceType::ADRENO; + } else if (device_name.find("swiftshader") != std::string::npos) { + return DeviceType::SWIFTSHADER; + } else if (device_name.find("nvidia") != std::string::npos) { + return DeviceType::NVIDIA; + } else if ( + device_name.find("mali") != std::string::npos || + device_name.find("immortalis") != std::string::npos) { + return DeviceType::MALI; + } + return DeviceType::UNKNOWN; +} + +} // namespace + PhysicalDevice::PhysicalDevice( VkInstance instance_handle, VkPhysicalDevice physical_device_handle) @@ -126,17 +145,7 @@ PhysicalDevice::PhysicalDevice( device_name.begin(), [](unsigned char c) { return std::tolower(c); }); - if (device_name.find("adreno") != std::string::npos) { - device_type = DeviceType::ADRENO; - } else if (device_name.find("swiftshader") != std::string::npos) { - device_type = DeviceType::SWIFTSHADER; - } else if (device_name.find("nvidia") != std::string::npos) { - device_type = DeviceType::NVIDIA; - } else if ( - device_name.find("mali") != std::string::npos || - device_name.find("immortalis") != std::string::npos) { - device_type = DeviceType::MALI; - } + device_type = determine_device_type(device_name); } void PhysicalDevice::query_extensions_vk_1_0() { @@ -292,6 +301,17 @@ void PhysicalDevice::query_extensions_vk_1_1() { vkGetPhysicalDeviceProperties2(handle, &properties2); } +void PhysicalDevice::override_device_name(const std::string& new_name) { + device_name = new_name; + 
std::transform( + device_name.begin(), + device_name.end(), + device_name.begin(), + [](unsigned char c) { return std::tolower(c); }); + + device_type = determine_device_type(device_name); +} + // // DeviceHandle // diff --git a/backends/vulkan/runtime/vk_api/Device.h b/backends/vulkan/runtime/vk_api/Device.h index ac5e381e46a..9fa413b2457 100644 --- a/backends/vulkan/runtime/vk_api/Device.h +++ b/backends/vulkan/runtime/vk_api/Device.h @@ -84,6 +84,9 @@ struct PhysicalDevice final { private: void query_extensions_vk_1_0(); void query_extensions_vk_1_1(); + + public: + void override_device_name(const std::string& new_name); }; struct DeviceHandle final {