From 31f1d0bdc8361f6c02951907ed865be116844f99 Mon Sep 17 00:00:00 2001 From: ssjia Date: Mon, 2 Mar 2026 13:03:41 -0800 Subject: [PATCH] [ET-VK][qconv] Add dynamic PACKED_INT8_CONV2D memory layout for device-adaptive conv2d Performance testing of quantized int8 convolutions reveals that different algorithms perform better on different GPU architectures: im2col is faster on Mali while direct convolution is faster on Adreno. The optimal memory layout differs per algorithm (4C for im2col, 4C1W for direct convolution). This introduces a new "dynamic" memory layout PACKED_INT8_CONV2D that is serialized at export time and resolved to a concrete layout at runtime based on the device's GPU architecture. The resolution logic in ResolveLayouts.cpp mirrors the im2col vs direct convolution decision in Q8taConv2d.cpp. Differential Revision: [D94949134](https://our.internmc.facebook.com/intern/diff/D94949134/) [ghstack-poisoned] --- backends/vulkan/op_registry.py | 2 +- backends/vulkan/runtime/ResolveLayouts.cpp | 206 ++++++++++++++++++ backends/vulkan/runtime/ResolveLayouts.h | 26 +++ backends/vulkan/runtime/VulkanBackend.cpp | 37 +++- backends/vulkan/serialization/schema.fbs | 2 + .../serialization/vulkan_graph_schema.py | 2 + backends/vulkan/utils.py | 20 +- 7 files changed, 289 insertions(+), 6 deletions(-) create mode 100644 backends/vulkan/runtime/ResolveLayouts.cpp create mode 100644 backends/vulkan/runtime/ResolveLayouts.h diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index b18bf3b81c6..62997ea956f 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -839,7 +839,7 @@ def register_q8ta_conv_pw_op(): def register_q8ta_conv2d_ops(): return OpFeatures( inputs_storage=[ - utils.PACKED_INT8_4C1W_BUFFER, # input + utils.PACKED_INT8_CONV2D_BUFFER, # input utils.NO_STORAGE, # input_scale (non tensor) utils.NO_STORAGE, # input_zero_point (non tensor) utils.NO_STORAGE, # weight (prepacked) diff --git 
a/backends/vulkan/runtime/ResolveLayouts.cpp b/backends/vulkan/runtime/ResolveLayouts.cpp
new file mode 100644
index 00000000000..7f3b5934123
--- /dev/null
+++ b/backends/vulkan/runtime/ResolveLayouts.cpp
@@ -0,0 +1,206 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/backends/vulkan/runtime/ResolveLayouts.h>
+
+#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
+
+#include <unordered_map>
+// NOTE(review): the include targets above were lost when this patch was
+// re-extracted (empty "#include" directives); paths are reconstructed
+// best-guesses — verify against the repository before applying.
+
+namespace vkcompute {
+
+namespace {
+
+using VkGraphPtr = const vkgraph::VkGraph*;
+using OpCallPtr = const vkgraph::OperatorCall*;
+using VkValuePtr = const vkgraph::VkValue*;
+using VkTensorPtr = const vkgraph::VkTensor*;
+using UIntVector = const flatbuffers::Vector<uint32_t>*;
+
+// Returns true if the serialized layout is the "dynamic" placeholder that
+// must be resolved to a concrete layout at runtime.
+bool is_dynamic_layout(const vkgraph::VkMemoryLayout layout) {
+  return layout == vkgraph::VkMemoryLayout::PACKED_INT8_CONV2D;
+}
+
+// Returns true if the layout is a concrete (non-dynamic) packed int8 layout.
+bool is_packed_int8_layout(vkgraph::VkMemoryLayout layout) {
+  switch (layout) {
+    case vkgraph::VkMemoryLayout::PACKED_INT8_4W4C:
+    case vkgraph::VkMemoryLayout::PACKED_INT8_4H4W:
+    case vkgraph::VkMemoryLayout::PACKED_INT8_4W:
+    case vkgraph::VkMemoryLayout::PACKED_INT8_4C:
+    case vkgraph::VkMemoryLayout::PACKED_INT8_4C1W:
+      return true;
+    default:
+      return false;
+  }
+}
+
+// Returns the effective layout of value `fb_id`: a previously recorded
+// override wins over the serialized layout; non-tensor values report
+// DEFAULT_LAYOUT.
+vkgraph::VkMemoryLayout get_resolved_layout(
+    uint32_t fb_id,
+    VkGraphPtr flatbuffer,
+    const std::unordered_map<uint32_t, vkgraph::VkMemoryLayout>&
+        memory_layout_overrides) {
+  auto it = memory_layout_overrides.find(fb_id);
+  if (it != memory_layout_overrides.end()) {
+    return it->second;
+  }
+  VkValuePtr value = flatbuffer->values()->Get(fb_id);
+  if (value->value_type() != vkgraph::GraphTypes::VkTensor) {
+    return vkgraph::VkMemoryLayout::DEFAULT_LAYOUT;
+  }
+  return value->value_as_VkTensor()->memory_layout();
+}
+
+// Propagates a concrete layout found among an op's args to any args whose
+// layout is still the dynamic placeholder.
+void resolve_dynamic_args(
+    VkGraphPtr flatbuffer,
+    OpCallPtr op_call,
+    std::unordered_map<uint32_t, vkgraph::VkMemoryLayout>&
+        memory_layout_overrides) {
+  // Find the first arg tensor with a non-dynamic packed int8 layout
+  vkgraph::VkMemoryLayout resolved_layout =
+      vkgraph::VkMemoryLayout::DEFAULT_LAYOUT;
+  bool found = false;
+  for (int i = 0; i < op_call->args()->size(); ++i) {
+    const uint32_t fb_id = static_cast<uint32_t>(op_call->args()->Get(i));
+    VkValuePtr value = flatbuffer->values()->Get(fb_id);
+    if (value->value_type() != vkgraph::GraphTypes::VkTensor) {
+      continue;
+    }
+    vkgraph::VkMemoryLayout layout =
+        get_resolved_layout(fb_id, flatbuffer, memory_layout_overrides);
+    if (is_packed_int8_layout(layout)) {
+      resolved_layout = layout;
+      found = true;
+      break;
+    }
+  }
+
+  if (!found) {
+    return;
+  }
+
+  // Override all args whose resolved layout is still dynamic
+  for (int i = 0; i < op_call->args()->size(); ++i) {
+    const uint32_t fb_id = static_cast<uint32_t>(op_call->args()->Get(i));
+    vkgraph::VkMemoryLayout layout =
+        get_resolved_layout(fb_id, flatbuffer, memory_layout_overrides);
+    if (is_dynamic_layout(layout)) {
+      memory_layout_overrides[fb_id] = resolved_layout;
+    }
+  }
+}
+
+// Resolves the dynamic input layout of a q8ta_conv2d call to PACKED_INT8_4C
+// (im2col path) or PACKED_INT8_4C1W (direct conv path), mirroring the
+// algorithm-selection logic in Q8taConv2d.cpp.
+void resolve_q8ta_conv2d(
+    VkGraphPtr flatbuffer,
+    OpCallPtr op_call,
+    ComputeGraph* compute_graph,
+    std::unordered_map<uint32_t, vkgraph::VkMemoryLayout>&
+        memory_layout_overrides) {
+  // q8ta_conv2d args layout:
+  // 0: input, 1: input_scale, 2: input_zp, 3: weight, 4: weight_sums,
+  // 5: weight_scales, 6: output_scale, 7: output_zp, 8: bias,
+  // 9: kernel_size, 10: stride, 11: padding, 12: dilation, 13: groups,
+  // 14: activation, 15: output
+
+  const uint32_t input_fb_id = static_cast<uint32_t>(op_call->args()->Get(0));
+  const uint32_t groups_fb_id = static_cast<uint32_t>(op_call->args()->Get(13));
+  const uint32_t output_fb_id = static_cast<uint32_t>(op_call->args()->Get(15));
+
+  // Only resolve if the input tensor has a dynamic layout
+  VkTensorPtr input_tensor =
+      flatbuffer->values()->Get(input_fb_id)->value_as_VkTensor();
+  if (!is_dynamic_layout(input_tensor->memory_layout())) {
+    return;
+  }
+
+  // Extract groups value
+  VkValuePtr groups_value = flatbuffer->values()->Get(groups_fb_id);
+  const int64_t groups = groups_value->value_as_Int()->int_val();
+
+  // Extract input tensor dimensions
+  UIntVector input_dims = input_tensor->dims();
+  const int64_t input_ndim = input_dims->size();
+  const int64_t in_channels = input_dims->Get(input_ndim - 3);
+  const int64_t in_channels_per_group = in_channels / groups;
+
+  // Extract output tensor dimensions
+  VkTensorPtr output_tensor =
+      flatbuffer->values()->Get(output_fb_id)->value_as_VkTensor();
+  UIntVector output_dims = output_tensor->dims();
+  const int64_t output_ndim = output_dims->size();
+  const int64_t H_out = output_dims->Get(output_ndim - 2);
+  const int64_t W_out = output_dims->Get(output_ndim - 1);
+  const int64_t spatial_out = H_out * W_out;
+
+  // Replicate the im2col decision logic from Q8taConv2d.cpp
+  const bool im2col_eligible = in_channels_per_group % 4 == 0;
+
+  bool use_im2col = false;
+  if (compute_graph->device_is_mali()) {
+    use_im2col = im2col_eligible;
+  } else {
+    use_im2col = im2col_eligible && groups == 1 &&
+        (in_channels_per_group >= 32 || spatial_out <= 4096);
+  }
+
+  if (use_im2col) {
+    memory_layout_overrides[input_fb_id] =
+        vkgraph::VkMemoryLayout::PACKED_INT8_4C;
+  } else {
+    memory_layout_overrides[input_fb_id] =
+        vkgraph::VkMemoryLayout::PACKED_INT8_4C1W;
+  }
+}
+
+// Depthwise conv always takes the direct-conv layout (PACKED_INT8_4C1W).
+void resolve_q8ta_conv2d_dw(
+    VkGraphPtr flatbuffer,
+    OpCallPtr op_call,
+    std::unordered_map<uint32_t, vkgraph::VkMemoryLayout>&
+        memory_layout_overrides) {
+  const uint32_t input_fb_id = static_cast<uint32_t>(op_call->args()->Get(0));
+
+  // Only override if not already overridden by a previous op
+  if (memory_layout_overrides.count(input_fb_id) > 0) {
+    return;
+  }
+
+  // Only resolve if the input tensor has a dynamic layout
+  VkTensorPtr input_tensor =
+      flatbuffer->values()->Get(input_fb_id)->value_as_VkTensor();
+  if (!is_dynamic_layout(input_tensor->memory_layout())) {
+    return;
+  }
+
+  memory_layout_overrides[input_fb_id] =
+      vkgraph::VkMemoryLayout::PACKED_INT8_4C1W;
+}
+
+} // namespace
+
+void resolve_memory_layouts(
+    const vkgraph::VkGraph* flatbuffer,
+    ComputeGraph* compute_graph,
+    std::unordered_map<uint32_t, vkgraph::VkMemoryLayout>&
+        memory_layout_overrides) {
+  // First, handle ops where input memory layout is impactful for performance
+  for (const auto* op_call : *(flatbuffer->chain())) {
+    const std::string op_name = op_call->name()->str();
+
+    if (op_name == "et_vk.q8ta_conv2d.default") {
+      resolve_q8ta_conv2d(
+          flatbuffer, op_call, compute_graph, memory_layout_overrides);
+    } else if (op_name == "et_vk.q8ta_conv2d_dw.default") {
+      resolve_q8ta_conv2d_dw(flatbuffer, op_call, memory_layout_overrides);
+    }
+  }
+  // Then, try to ensure ops use the same memory layout whenever possible.
+  for (const auto* op_call : *(flatbuffer->chain())) {
+    resolve_dynamic_args(flatbuffer, op_call, memory_layout_overrides);
+  }
+}
+
+} // namespace vkcompute
diff --git a/backends/vulkan/runtime/ResolveLayouts.h b/backends/vulkan/runtime/ResolveLayouts.h
new file mode 100644
index 00000000000..332d192d4af
--- /dev/null
+++ b/backends/vulkan/runtime/ResolveLayouts.h
@@ -0,0 +1,26 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <cstdint>
+#include <unordered_map>
+
+#include <executorch/backends/vulkan/serialization/schema_generated.h>
+// NOTE(review): include targets reconstructed (stripped in extraction); verify.
+
+namespace vkcompute {
+
+class ComputeGraph;
+
+// Walks the serialized graph and records, per flatbuffer value id, a concrete
+// memory layout override for every tensor whose serialized layout is the
+// dynamic PACKED_INT8_CONV2D placeholder.
+void resolve_memory_layouts(
+    const vkgraph::VkGraph* flatbuffer,
+    ComputeGraph* compute_graph,
+    std::unordered_map<uint32_t, vkgraph::VkMemoryLayout>&
+        memory_layout_overrides);
+
+} // namespace vkcompute
diff --git a/backends/vulkan/runtime/VulkanBackend.cpp b/backends/vulkan/runtime/VulkanBackend.cpp
index 7f7afffcf57..d4eeb9b1dd4 100644
--- a/backends/vulkan/runtime/VulkanBackend.cpp
+++ b/backends/vulkan/runtime/VulkanBackend.cpp
@@ -6,6 +6,7 @@
  * LICENSE file in the root directory of this source tree.
*/ +#include #include #include @@ -32,6 +33,7 @@ #include #include #include +#include #include namespace executorch { @@ -146,8 +148,13 @@ utils::GPUMemoryLayout get_memory_layout( return utils::kPackedInt8_4H4W; case vkgraph::VkMemoryLayout::PACKED_INT8_4W: return utils::kPackedInt8_4W; + case vkgraph::VkMemoryLayout::PACKED_INT8_4C: + return utils::kPackedInt8_4C; case vkgraph::VkMemoryLayout::PACKED_INT8_4C1W: return utils::kPackedInt8_4C1W; + case vkgraph::VkMemoryLayout::PACKED_INT8_CONV2D: + // Fallback for unresolved dynamic layout + return utils::kPackedInt8_4C1W; default: break; } @@ -205,6 +212,8 @@ class GraphBuilder { std::vector loaded_buffers_from_map_; std::vector ref_mapping_; + std::unordered_map + memory_layout_overrides_; public: explicit GraphBuilder( @@ -217,7 +226,13 @@ class GraphBuilder { constant_data_(constant_data), named_data_map_(named_data_map), loaded_buffers_from_map_(), - ref_mapping_() {} + ref_mapping_(), + memory_layout_overrides_() {} + + void resolve_layouts() { + resolve_memory_layouts( + flatbuffer_, compute_graph_, memory_layout_overrides_); + } void resize(uint32_t size) { ref_mapping_.resize(size, INT32_MAX); @@ -235,6 +250,21 @@ class GraphBuilder { return ref_mapping_[fb_id]; } + utils::GPUMemoryLayout get_resolved_memory_layout( + const uint32_t fb_id, + VkTensorPtr tensor_fb, + const std::vector& dims_vector) { + auto it = memory_layout_overrides_.find(fb_id); + if (it != memory_layout_overrides_.end()) { + return get_memory_layout(it->second); + } + + if (tensor_fb->memory_layout() == vkgraph::VkMemoryLayout::DEFAULT_LAYOUT) { + return compute_graph_->suggested_memory_layout(dims_vector); + } + return get_memory_layout(tensor_fb->memory_layout()); + } + void add_tensor_to_graph(const uint32_t fb_id, VkTensorPtr tensor_fb) { const vkapi::ScalarType& dtype = get_scalar_type(tensor_fb->datatype()); utils::StorageType storage_type = @@ -246,9 +276,7 @@ class GraphBuilder { const std::vector dims_vector(dims_fb->cbegin(), 
dims_fb->cend()); utils::GPUMemoryLayout memory_layout = - tensor_fb->memory_layout() == vkgraph::VkMemoryLayout::DEFAULT_LAYOUT - ? compute_graph_->suggested_memory_layout(dims_vector) - : get_memory_layout(tensor_fb->memory_layout()); + get_resolved_memory_layout(fb_id, tensor_fb, dims_vector); ValueRef ref; if (tensor_fb->constant_id() >= 0) { @@ -593,6 +621,7 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface { GraphBuilder builder( compute_graph, flatbuffer_graph, constant_data, named_data_map); + builder.resolve_layouts(); builder.build_graph(); compute_graph->prepare(); diff --git a/backends/vulkan/serialization/schema.fbs b/backends/vulkan/serialization/schema.fbs index 36f9feaa580..92cee15cfe8 100644 --- a/backends/vulkan/serialization/schema.fbs +++ b/backends/vulkan/serialization/schema.fbs @@ -43,7 +43,9 @@ enum VkMemoryLayout : ubyte { PACKED_INT8_4W4C = 3, PACKED_INT8_4H4W = 4, PACKED_INT8_4W = 5, + PACKED_INT8_4C = 6, PACKED_INT8_4C1W = 8, + PACKED_INT8_CONV2D = 9, DEFAULT_LAYOUT = 255, } diff --git a/backends/vulkan/serialization/vulkan_graph_schema.py b/backends/vulkan/serialization/vulkan_graph_schema.py index 845a59a4dff..c53111c2092 100644 --- a/backends/vulkan/serialization/vulkan_graph_schema.py +++ b/backends/vulkan/serialization/vulkan_graph_schema.py @@ -51,7 +51,9 @@ class VkMemoryLayout(IntEnum): PACKED_INT8_4W4C = 3 PACKED_INT8_4H4W = 4 PACKED_INT8_4W = 5 + PACKED_INT8_4C = 6 PACKED_INT8_4C1W = 8 + PACKED_INT8_CONV2D = 9 DEFAULT_LAYOUT = 255 def __str__(self) -> str: diff --git a/backends/vulkan/utils.py b/backends/vulkan/utils.py index dde9aaac973..2a3c3910c48 100644 --- a/backends/vulkan/utils.py +++ b/backends/vulkan/utils.py @@ -606,6 +606,7 @@ def node_has_target(node: Any, target: str): VkMemoryLayout.PACKED_INT8_4H4W, VkMemoryLayout.PACKED_INT8_4W, VkMemoryLayout.PACKED_INT8_4C1W, + VkMemoryLayout.PACKED_INT8_CONV2D, } universal_memory_layout_set: Set[VkMemoryLayout] = ( @@ -622,6 +623,7 @@ def 
node_has_target(node: Any, target: str): VkMemoryLayout.PACKED_INT8_4W4C: 2, VkMemoryLayout.PACKED_INT8_4H4W: 0, VkMemoryLayout.PACKED_INT8_4C1W: 2, + VkMemoryLayout.PACKED_INT8_CONV2D: 2, } @@ -686,6 +688,11 @@ def from_repr( packed_dim=2, packed_dim_block_size=4 if is_buffer else 16, ) + elif memory_layout == VkMemoryLayout.PACKED_INT8_CONV2D: + return cls( + packed_dim=2, + packed_dim_block_size=4, + ) else: raise ValueError(f"Unknown memory layout: {memory_layout}") @@ -742,6 +749,10 @@ def required_image_extents(sizes: torch.Size, layout: VkMemoryLayout) -> ImageEx elif layout == VkMemoryLayout.PACKED_INT8_4H4W: height = (height + 3) // 4 width = (width + 3) // 4 + elif layout == VkMemoryLayout.PACKED_INT8_CONV2D: + # Use conservative extents (same as 4W4C) since this is buffer-only + width = (width + 3) // 4 + channels = (channels + 3) // 4 else: raise RuntimeError(f"Unsupported memory layout {layout}") @@ -1175,8 +1186,15 @@ def filter_invalid_reprs_for_node_list( PACKED_INT8_4W_BUFFER = TensorRepSet({VkMemoryLayout.PACKED_INT8_4W}, set()) PACKED_INT8_4C1W_BUFFER = TensorRepSet({VkMemoryLayout.PACKED_INT8_4C1W}, set()) +PACKED_INT8_CONV2D_BUFFER = TensorRepSet({VkMemoryLayout.PACKED_INT8_CONV2D}, set()) + PACKED_INT8_CHANNELS_PACKED_BUFFER = TensorRepSet( - {VkMemoryLayout.PACKED_INT8_4W4C, VkMemoryLayout.PACKED_INT8_4C1W}, set() + { + VkMemoryLayout.PACKED_INT8_4W4C, + VkMemoryLayout.PACKED_INT8_4C1W, + VkMemoryLayout.PACKED_INT8_CONV2D, + }, + set(), )