Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion backends/vulkan/op_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -839,7 +839,7 @@ def register_q8ta_conv_pw_op():
def register_q8ta_conv2d_ops():
return OpFeatures(
inputs_storage=[
utils.PACKED_INT8_4C1W_BUFFER, # input
utils.PACKED_INT8_CONV2D_BUFFER, # input
utils.NO_STORAGE, # input_scale (non tensor)
utils.NO_STORAGE, # input_zero_point (non tensor)
utils.NO_STORAGE, # weight (prepacked)
Expand Down
206 changes: 206 additions & 0 deletions backends/vulkan/runtime/ResolveLayouts.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/vulkan/runtime/ResolveLayouts.h>

#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>

#include <string>

namespace vkcompute {

namespace {

// Convenience aliases for the read-only flatbuffer types that this
// translation unit traverses. All are pointers into the serialized graph
// owned by the caller; none of these aliases confer ownership.
using VkGraphPtr = const vkgraph::VkGraph*;
using OpCallPtr = const vkgraph::OperatorCall*;
using VkValuePtr = const vkgraph::VkValue*;
using VkTensorPtr = const vkgraph::VkTensor*;
// Dimension vector as serialized in the flatbuffer (unsigned 32-bit extents).
using UIntVector = const flatbuffers::Vector<uint32_t>*;

// Returns true when `layout` is a dynamic placeholder that must be resolved
// to a concrete packed layout before the compute graph is built.
bool is_dynamic_layout(const vkgraph::VkMemoryLayout layout) {
  switch (layout) {
    case vkgraph::VkMemoryLayout::PACKED_INT8_CONV2D:
      return true;
    default:
      return false;
  }
}

// Returns true when `layout` is one of the concrete (non-dynamic) packed
// int8 memory layouts.
bool is_packed_int8_layout(vkgraph::VkMemoryLayout layout) {
  return layout == vkgraph::VkMemoryLayout::PACKED_INT8_4W4C ||
      layout == vkgraph::VkMemoryLayout::PACKED_INT8_4H4W ||
      layout == vkgraph::VkMemoryLayout::PACKED_INT8_4W ||
      layout == vkgraph::VkMemoryLayout::PACKED_INT8_4C ||
      layout == vkgraph::VkMemoryLayout::PACKED_INT8_4C1W;
}

// Resolves the effective memory layout of the value with flatbuffer id
// `fb_id`: an entry in `memory_layout_overrides` takes precedence; otherwise
// the layout serialized on the tensor is returned. Non-tensor values resolve
// to DEFAULT_LAYOUT.
vkgraph::VkMemoryLayout get_resolved_layout(
    uint32_t fb_id,
    VkGraphPtr flatbuffer,
    const std::unordered_map<uint32_t, vkgraph::VkMemoryLayout>&
        memory_layout_overrides) {
  const auto override_it = memory_layout_overrides.find(fb_id);
  if (override_it != memory_layout_overrides.end()) {
    return override_it->second;
  }
  VkValuePtr value = flatbuffer->values()->Get(fb_id);
  if (value->value_type() == vkgraph::GraphTypes::VkTensor) {
    return value->value_as_VkTensor()->memory_layout();
  }
  return vkgraph::VkMemoryLayout::DEFAULT_LAYOUT;
}

// Propagates a concrete packed int8 layout to all args of `op_call` whose
// resolved layout is still the dynamic PACKED_INT8_CONV2D placeholder.
//
// Scans the op's args for the first tensor whose resolved layout is a
// concrete packed int8 layout; if one exists, every arg tensor still carrying
// a dynamic layout is overridden to that layout so the tensors participating
// in the op agree on a single packing scheme. Entries are recorded in
// `memory_layout_overrides`, keyed by flatbuffer value id.
void resolve_dynamic_args(
    VkGraphPtr flatbuffer,
    OpCallPtr op_call,
    std::unordered_map<uint32_t, vkgraph::VkMemoryLayout>&
        memory_layout_overrides) {
  // Fix: use flatbuffers::uoffset_t for the loop index; args()->size()
  // returns an unsigned type, and comparing against a signed int triggers
  // -Wsign-compare (an error under -Werror builds). Also hoist the repeated
  // size() call out of both loops.
  const flatbuffers::uoffset_t num_args = op_call->args()->size();

  // Find the first arg tensor with a non-dynamic packed int8 layout.
  vkgraph::VkMemoryLayout resolved_layout =
      vkgraph::VkMemoryLayout::DEFAULT_LAYOUT;
  bool found = false;
  for (flatbuffers::uoffset_t i = 0; i < num_args; ++i) {
    const uint32_t fb_id = static_cast<uint32_t>(op_call->args()->Get(i));
    VkValuePtr value = flatbuffer->values()->Get(fb_id);
    if (value->value_type() != vkgraph::GraphTypes::VkTensor) {
      continue;
    }
    vkgraph::VkMemoryLayout layout =
        get_resolved_layout(fb_id, flatbuffer, memory_layout_overrides);
    if (is_packed_int8_layout(layout)) {
      resolved_layout = layout;
      found = true;
      break;
    }
  }

  // No concrete packed int8 layout among the args; nothing to propagate.
  if (!found) {
    return;
  }

  // Override all args whose resolved layout is still dynamic.
  // get_resolved_layout returns DEFAULT_LAYOUT for non-tensor values, which
  // is never dynamic, so no explicit tensor-type check is needed here.
  for (flatbuffers::uoffset_t i = 0; i < num_args; ++i) {
    const uint32_t fb_id = static_cast<uint32_t>(op_call->args()->Get(i));
    vkgraph::VkMemoryLayout layout =
        get_resolved_layout(fb_id, flatbuffer, memory_layout_overrides);
    if (is_dynamic_layout(layout)) {
      memory_layout_overrides[fb_id] = resolved_layout;
    }
  }
}

// Chooses a concrete packed int8 layout for the input tensor of a
// et_vk.q8ta_conv2d op whose input was serialized with the dynamic
// PACKED_INT8_CONV2D layout. The decision replicates the runtime im2col
// eligibility logic (per the comment below, from Q8taConv2d.cpp) so the
// input is packed in the layout the selected conv kernel will consume.
// The chosen layout is recorded in `memory_layout_overrides`, keyed by the
// input's flatbuffer value id.
//
// NOTE(review): assumes op_call->args() holds at least 16 entries laid out
// as documented below, that arg 13 is an Int value, and that the input and
// output tensors have >= 3 and >= 2 dims respectively — verify against the
// q8ta_conv2d serialization code.
void resolve_q8ta_conv2d(
    VkGraphPtr flatbuffer,
    OpCallPtr op_call,
    ComputeGraph* compute_graph,
    std::unordered_map<uint32_t, vkgraph::VkMemoryLayout>&
        memory_layout_overrides) {
  // q8ta_conv2d args layout:
  // 0: input, 1: input_scale, 2: input_zp, 3: weight, 4: weight_sums,
  // 5: weight_scales, 6: output_scale, 7: output_zp, 8: bias,
  // 9: kernel_size, 10: stride, 11: padding, 12: dilation, 13: groups,
  // 14: activation, 15: output

  const uint32_t input_fb_id = static_cast<uint32_t>(op_call->args()->Get(0));
  const uint32_t groups_fb_id = static_cast<uint32_t>(op_call->args()->Get(13));
  const uint32_t output_fb_id = static_cast<uint32_t>(op_call->args()->Get(15));

  // Only resolve if the input tensor has a dynamic layout
  VkTensorPtr input_tensor =
      flatbuffer->values()->Get(input_fb_id)->value_as_VkTensor();
  if (!is_dynamic_layout(input_tensor->memory_layout())) {
    return;
  }

  // Extract groups value
  VkValuePtr groups_value = flatbuffer->values()->Get(groups_fb_id);
  const int64_t groups = groups_value->value_as_Int()->int_val();

  // Extract input tensor dimensions. Channels are taken from the
  // third-from-last dim, i.e. (..., C, H, W) ordering.
  UIntVector input_dims = input_tensor->dims();
  const int64_t input_ndim = input_dims->size();
  const int64_t in_channels = input_dims->Get(input_ndim - 3);
  const int64_t in_channels_per_group = in_channels / groups;

  // Extract output tensor dimensions (last two dims are H, W).
  VkTensorPtr output_tensor =
      flatbuffer->values()->Get(output_fb_id)->value_as_VkTensor();
  UIntVector output_dims = output_tensor->dims();
  const int64_t output_ndim = output_dims->size();
  const int64_t H_out = output_dims->Get(output_ndim - 2);
  const int64_t W_out = output_dims->Get(output_ndim - 1);
  const int64_t spatial_out = H_out * W_out;

  // Replicate the im2col decision logic from Q8taConv2d.cpp
  const bool im2col_eligible = in_channels_per_group % 4 == 0;

  // On Mali devices im2col is used whenever eligible; on other devices it is
  // additionally gated on groups == 1 and an input-channel / output-size
  // heuristic. Keep in sync with Q8taConv2d.cpp.
  bool use_im2col = false;
  if (compute_graph->device_is_mali()) {
    use_im2col = im2col_eligible;
  } else {
    use_im2col = im2col_eligible && groups == 1 &&
        (in_channels_per_group >= 32 || spatial_out <= 4096);
  }

  // im2col path consumes PACKED_INT8_4C; the direct conv path consumes
  // PACKED_INT8_4C1W.
  if (use_im2col) {
    memory_layout_overrides[input_fb_id] =
        vkgraph::VkMemoryLayout::PACKED_INT8_4C;
  } else {
    memory_layout_overrides[input_fb_id] =
        vkgraph::VkMemoryLayout::PACKED_INT8_4C1W;
  }
}

// Resolves the input layout of a depthwise q8ta conv op: when the input is
// still serialized with the dynamic placeholder layout and no earlier op has
// already chosen a layout for that tensor, pin it to PACKED_INT8_4C1W.
void resolve_q8ta_conv2d_dw(
    VkGraphPtr flatbuffer,
    OpCallPtr op_call,
    std::unordered_map<uint32_t, vkgraph::VkMemoryLayout>&
        memory_layout_overrides) {
  const uint32_t input_id = static_cast<uint32_t>(op_call->args()->Get(0));

  // An override placed by a previously processed op takes precedence.
  const bool already_overridden = memory_layout_overrides.count(input_id) > 0;
  if (already_overridden) {
    return;
  }

  // Only act on inputs whose serialized layout is the dynamic placeholder.
  VkTensorPtr input_tensor =
      flatbuffer->values()->Get(input_id)->value_as_VkTensor();
  if (is_dynamic_layout(input_tensor->memory_layout())) {
    memory_layout_overrides[input_id] =
        vkgraph::VkMemoryLayout::PACKED_INT8_4C1W;
  }
}

} // namespace

// Computes memory layout overrides for tensors whose serialized layout is
// dynamic. Pass 1 resolves performance-sensitive ops (the q8ta conv
// variants); pass 2 propagates the chosen layouts so ops sharing tensors use
// consistent layouts wherever possible.
void resolve_memory_layouts(
    const vkgraph::VkGraph* flatbuffer,
    ComputeGraph* compute_graph,
    std::unordered_map<uint32_t, vkgraph::VkMemoryLayout>&
        memory_layout_overrides) {
  // Pass 1: ops where input memory layout is impactful for performance.
  for (const auto* node : *(flatbuffer->chain())) {
    const std::string node_name = node->name()->str();
    if (node_name == "et_vk.q8ta_conv2d.default") {
      resolve_q8ta_conv2d(
          flatbuffer, node, compute_graph, memory_layout_overrides);
      continue;
    }
    if (node_name == "et_vk.q8ta_conv2d_dw.default") {
      resolve_q8ta_conv2d_dw(flatbuffer, node, memory_layout_overrides);
    }
  }
  // Pass 2: try to ensure ops use the same memory layout whenever possible.
  for (const auto* node : *(flatbuffer->chain())) {
    resolve_dynamic_args(flatbuffer, node, memory_layout_overrides);
  }
}

} // namespace vkcompute
26 changes: 26 additions & 0 deletions backends/vulkan/runtime/ResolveLayouts.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <cstdint>
#include <unordered_map>

#include <executorch/backends/vulkan/serialization/schema_generated.h>

namespace vkcompute {

class ComputeGraph;

/// Computes memory layout overrides for tensors in the serialized graph
/// whose layout must be resolved at load time (e.g. dynamic packed int8
/// layouts), writing the result into `memory_layout_overrides` keyed by
/// flatbuffer value id. `compute_graph` is consulted for device-specific
/// decisions. Call before tensors are added to the compute graph.
void resolve_memory_layouts(
    const vkgraph::VkGraph* flatbuffer,
    ComputeGraph* compute_graph,
    std::unordered_map<uint32_t, vkgraph::VkMemoryLayout>&
        memory_layout_overrides);

} // namespace vkcompute
37 changes: 33 additions & 4 deletions backends/vulkan/runtime/VulkanBackend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/vulkan/runtime/ResolveLayouts.h>
#include <executorch/backends/vulkan/runtime/VulkanDelegateHeader.h>
#include <executorch/backends/vulkan/serialization/schema_generated.h>

Expand All @@ -32,6 +33,7 @@
#include <cstring>
#include <memory>
#include <type_traits>
#include <unordered_map>
#include <vector>

namespace executorch {
Expand Down Expand Up @@ -146,8 +148,13 @@ utils::GPUMemoryLayout get_memory_layout(
return utils::kPackedInt8_4H4W;
case vkgraph::VkMemoryLayout::PACKED_INT8_4W:
return utils::kPackedInt8_4W;
case vkgraph::VkMemoryLayout::PACKED_INT8_4C:
return utils::kPackedInt8_4C;
case vkgraph::VkMemoryLayout::PACKED_INT8_4C1W:
return utils::kPackedInt8_4C1W;
case vkgraph::VkMemoryLayout::PACKED_INT8_CONV2D:
// Fallback for unresolved dynamic layout
return utils::kPackedInt8_4C1W;
default:
break;
}
Expand Down Expand Up @@ -205,6 +212,8 @@ class GraphBuilder {
std::vector<FreeableBuffer> loaded_buffers_from_map_;

std::vector<ValueRef> ref_mapping_;
std::unordered_map<uint32_t, vkgraph::VkMemoryLayout>
memory_layout_overrides_;

public:
explicit GraphBuilder(
Expand All @@ -217,7 +226,13 @@ class GraphBuilder {
constant_data_(constant_data),
named_data_map_(named_data_map),
loaded_buffers_from_map_(),
ref_mapping_() {}
ref_mapping_(),
memory_layout_overrides_() {}

// Populates memory_layout_overrides_ from the serialized graph. Must run
// before tensors are added so get_resolved_memory_layout can consult the
// overrides when choosing each tensor's GPU memory layout.
void resolve_layouts() {
  resolve_memory_layouts(
      flatbuffer_, compute_graph_, memory_layout_overrides_);
}

void resize(uint32_t size) {
ref_mapping_.resize(size, INT32_MAX);
Expand All @@ -235,6 +250,21 @@ class GraphBuilder {
return ref_mapping_[fb_id];
}

// Returns the GPU memory layout for the tensor with flatbuffer id `fb_id`.
// Precedence: (1) an explicit entry in memory_layout_overrides_ (computed by
// resolve_layouts()); (2) a layout suggested by the compute graph from the
// tensor's sizes, when the serialized layout is DEFAULT_LAYOUT; (3) the
// layout serialized on the tensor itself.
utils::GPUMemoryLayout get_resolved_memory_layout(
    const uint32_t fb_id,
    VkTensorPtr tensor_fb,
    const std::vector<int64_t>& dims_vector) {
  auto it = memory_layout_overrides_.find(fb_id);
  if (it != memory_layout_overrides_.end()) {
    return get_memory_layout(it->second);
  }

  if (tensor_fb->memory_layout() == vkgraph::VkMemoryLayout::DEFAULT_LAYOUT) {
    return compute_graph_->suggested_memory_layout(dims_vector);
  }
  return get_memory_layout(tensor_fb->memory_layout());
}

void add_tensor_to_graph(const uint32_t fb_id, VkTensorPtr tensor_fb) {
const vkapi::ScalarType& dtype = get_scalar_type(tensor_fb->datatype());
utils::StorageType storage_type =
Expand All @@ -246,9 +276,7 @@ class GraphBuilder {
const std::vector<int64_t> dims_vector(dims_fb->cbegin(), dims_fb->cend());

utils::GPUMemoryLayout memory_layout =
tensor_fb->memory_layout() == vkgraph::VkMemoryLayout::DEFAULT_LAYOUT
? compute_graph_->suggested_memory_layout(dims_vector)
: get_memory_layout(tensor_fb->memory_layout());
get_resolved_memory_layout(fb_id, tensor_fb, dims_vector);

ValueRef ref;
if (tensor_fb->constant_id() >= 0) {
Expand Down Expand Up @@ -593,6 +621,7 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
GraphBuilder builder(
compute_graph, flatbuffer_graph, constant_data, named_data_map);

builder.resolve_layouts();
builder.build_graph();

compute_graph->prepare();
Expand Down
5 changes: 5 additions & 0 deletions backends/vulkan/runtime/graph/ComputeGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -661,6 +661,11 @@ class ComputeGraph final {
inline bool device_is_adreno() {
return context_->adapter_ptr()->device_type() == vkapi::DeviceType::ADRENO;
}

// Returns true if the active Vulkan adapter reports its device type as
// ARM Mali. Used for device-specific kernel/layout selection.
inline bool device_is_mali() {
  return context_->adapter_ptr()->device_type() == vkapi::DeviceType::MALI;
}

const std::string& device_name() {
return context()->adapter_ptr()->device_name();
}
Expand Down
17 changes: 12 additions & 5 deletions backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_pw.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,14 @@ layout(push_constant) uniform restrict Block {
int input_zp;
float output_inv_scale;
int output_zp;
int K4_per_group;
int OC4_per_group;
};

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

${layout_declare_spec_const(C, "int", "apply_bias", "1")}
${layout_declare_spec_const(C, "int", "activation_type", "0")}
${layout_declare_spec_const(C, "int", "conv2d_params_K4_per_group", "1")}

// Layout specialization constants
${layout_declare_spec_const(C, "int", "outp_layout", "CONTIG_LAYOUT_INT")}
Expand Down Expand Up @@ -124,12 +125,18 @@ void main() {
}
}

// Compute initial input tile index
// Input has same spatial layout, channel dimension iterates from 0
int input_idx = oh * inp_h_stride + ow_block_idx * inp_w_stride;
// Compute group index from output channel block
const int group_idx = oc_block_idx / OC4_per_group;

// Compute initial input tile index with group offset
// For grouped im2col, each group's K range starts at group_idx * K4_per_group
// For non-grouped (groups=1), group_idx is always 0 so offset is 0
int input_idx = oh * inp_h_stride
+ ow_block_idx * inp_w_stride
+ group_idx * K4_per_group;

// Main accumulation loop over K dimension
for (int k4 = 0; k4 < conv2d_params_K4_per_group; k4++) {
for (int k4 = 0; k4 < K4_per_group; k4++) {
// Load packed int8 input tile (TILE_M4=1, TILE_K4=1)
// Each int contains 4 packed int8s (one per width position in the tile)
ivec4 int8_input_tile = t_packed_int8_input[input_idx];
Expand Down
Loading
Loading