Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions backends/cortex_m/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -53,21 +53,22 @@ endif()

# Cortex-M ops kernel sources
# Cortex-M CMSIS-NN kernel sources, listed alphabetically (one entry per op).
set(_cortex_m_kernels__srcs
    ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_dequantize_per_tensor.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_maximum.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_minimum.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_pad.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantize_per_tensor.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_add.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_avg_pool2d.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_batch_matmul.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_conv2d.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_depthwise_conv2d.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_linear.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_max_pool2d.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_mul.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_transpose_conv2d.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_softmax.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_transpose.cpp
)

# Generate C++ bindings to register kernels into Executorch
Expand Down
146 changes: 146 additions & 0 deletions backends/cortex_m/ops/op_quantized_batch_matmul.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include "cortex_m_ops_common.h"

extern "C" {
#include "arm_nnfunctions.h"
}

namespace cortex_m {
namespace native {

using KernelRuntimeContext = torch::executor::KernelRuntimeContext;

namespace {

// Validates dtypes and shapes for quantized_batch_matmul. On the first
// violated constraint, logs an error, marks the kernel context as failed
// with InvalidArgument, and returns false; returns true when all checks pass.
bool validate_batch_matmul_arguments(
    KernelRuntimeContext& context,
    const Tensor& lhs,
    const Tensor& rhs_transposed,
    const Tensor& out) {
  const bool all_int8 = lhs.scalar_type() == ScalarType::Char &&
      rhs_transposed.scalar_type() == ScalarType::Char &&
      out.scalar_type() == ScalarType::Char;
  if (!all_int8) {
    ET_LOG(Error, "quantized_batch_matmul: all tensors must be int8");
    context.fail(Error::InvalidArgument);
    return false;
  }

  const bool all_rank3 =
      lhs.dim() == 3 && rhs_transposed.dim() == 3 && out.dim() == 3;
  if (!all_rank3) {
    ET_LOG(Error, "quantized_batch_matmul: all tensors must be 3-D");
    context.fail(Error::InvalidArgument);
    return false;
  }

  // Both operands must carry the same batch count.
  if (rhs_transposed.size(0) != lhs.size(0)) {
    ET_LOG(Error, "quantized_batch_matmul: batch dims must match");
    context.fail(Error::InvalidArgument);
    return false;
  }

  // rhs is pre-transposed, so its last dim is the shared inner dim.
  if (rhs_transposed.size(2) != lhs.size(2)) {
    ET_LOG(Error, "quantized_batch_matmul: inner dims must match");
    context.fail(Error::InvalidArgument);
    return false;
  }

  const bool out_shape_ok = out.size(0) == lhs.size(0) &&
      out.size(1) == lhs.size(1) && out.size(2) == rhs_transposed.size(1);
  if (!out_shape_ok) {
    ET_LOG(Error, "quantized_batch_matmul: output shape mismatch");
    context.fail(Error::InvalidArgument);
    return false;
  }

  return true;
}

} // namespace

// Int8 batched matmul backed by CMSIS-NN's arm_batch_matmul_s8.
//
// Args:
//   context: runtime context used for temp allocation and failure reporting.
//   lhs: int8 tensor [batch, lhs_rows, inner].
//   lhs_offset: value added to each lhs element before accumulation
//     (CMSIS-NN convention: this is the negated zero point).
//   rhs_transposed: int8 right operand with last two dims pre-swapped,
//     i.e. [batch, rhs_cols, inner].
//   rhs_offset: offset added to each rhs element (negated zero point).
//   output_offset: zero point added after requantization.
//   output_multiplier / output_shift: per-tensor fixed-point requantization
//     parameters.
//   out: int8 output tensor [batch, lhs_rows, rhs_cols]; returned on both
//     success and failure (failure is signaled via context.fail()).
Tensor& quantized_batch_matmul_out(
    KernelRuntimeContext& context,
    const Tensor& lhs,
    int64_t lhs_offset,
    const Tensor& rhs_transposed,
    int64_t rhs_offset,
    int64_t output_offset,
    int64_t output_multiplier,
    int64_t output_shift,
    Tensor& out) {
  if (!validate_batch_matmul_arguments(context, lhs, rhs_transposed, out)) {
    return out;
  }

  const int32_t batch = static_cast<int32_t>(lhs.size(0));
  const int32_t lhs_rows = static_cast<int32_t>(lhs.size(1));
  const int32_t inner = static_cast<int32_t>(lhs.size(2));
  const int32_t rhs_cols = static_cast<int32_t>(rhs_transposed.size(1));

  // cmsis_nn_dims is {n, h, w, c}; both operands share the inner dim as `c`.
  const cmsis_nn_dims lhs_dims = {1, batch, lhs_rows, inner};
  const cmsis_nn_dims rhs_dims = {1, batch, rhs_cols, inner};
  const cmsis_nn_dims out_dims = {1, batch, lhs_rows, rhs_cols};

  // NOTE(review): rhs is supplied pre-transposed and both adj flags are
  // false — confirm this matches arm_batch_matmul_s8's expected layout.
  const cmsis_nn_bmm_params bmm_params = {
      /* adj_x */ false,
      /* adj_y */ false,
      /* fc_params */
      {static_cast<int32_t>(lhs_offset),
       static_cast<int32_t>(rhs_offset),
       static_cast<int32_t>(output_offset),
       /* activation */
       // No fused activation: clamp to the full int8 range.
       {std::numeric_limits<int8_t>::min(),
        std::numeric_limits<int8_t>::max()}}};

  // Single per-tensor requantization pair applied to every output element.
  cmsis_nn_per_tensor_quant_params quant_params;
  quant_params.multiplier = static_cast<int32_t>(output_multiplier);
  quant_params.shift = static_cast<int32_t>(output_shift);

  // NOTE(review): the buffer size is queried with out_dims, but
  // arm_fully_connected_s8_get_buffer_size documents a filter_dims
  // argument — confirm out_dims is the intended query for the batched path.
  const int32_t buf_size = arm_fully_connected_s8_get_buffer_size(&out_dims);

  cmsis_nn_context ctx;
  ctx.buf = nullptr;
  ctx.size = 0;

  // Scratch is only needed on some targets (buf_size may be 0); allocate
  // from the runtime's temp allocator when required.
  if (buf_size > 0) {
    auto buffer_or_error = context.allocate_temp(buf_size);
    if (!buffer_or_error.ok()) {
      ET_LOG(
          Error,
          "quantized_batch_matmul: failed to allocate scratch buffer (%d bytes)",
          buf_size);
      context.fail(buffer_or_error.error());
      return out;
    }
    ctx.buf = buffer_or_error.get();
    ctx.size = buf_size;
  }

  const arm_cmsis_nn_status status = arm_batch_matmul_s8(
      &ctx,
      &bmm_params,
      &quant_params,
      &lhs_dims,
      lhs.const_data_ptr<int8_t>(),
      &rhs_dims,
      rhs_transposed.const_data_ptr<int8_t>(),
      &out_dims,
      out.mutable_data_ptr<int8_t>());

  if (status != ARM_CMSIS_NN_SUCCESS) {
    ET_LOG(
        Error,
        "quantized_batch_matmul: arm_batch_matmul_s8 failed with status [%d]",
        status);
    context.fail(Error::Internal);
  }

  return out;
}

} // namespace native
} // namespace cortex_m
53 changes: 53 additions & 0 deletions backends/cortex_m/ops/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,59 @@ def quantized_mul_impl(
return result


# ===================================================================
# QUANTIZED BATCH MATMUL OPERATION DEFINITION
# ===================================================================
# Functional variant: returns a freshly allocated int8 result tensor.
# `rhs_transposed` is the right operand with its last two dims pre-swapped.
lib.define(
    "quantized_batch_matmul("
    "Tensor lhs, int lhs_zero_point, "
    "Tensor rhs_transposed, int rhs_zero_point, "
    "int output_zero_point, int output_multiplier, int output_shift) -> Tensor"
)
# Out variant: writes the result into the caller-provided `out` tensor;
# this is the schema bound to the C++ kernel in operators.yaml.
lib.define(
    "quantized_batch_matmul.out("
    "Tensor lhs, int lhs_zero_point, "
    "Tensor rhs_transposed, int rhs_zero_point, "
    "int output_zero_point, int output_multiplier, int output_shift, "
    "*, Tensor(a!) out) -> Tensor(a!)"
)


@register_fake("cortex_m::quantized_batch_matmul")
def quantized_batch_matmul_meta(
    lhs: torch.Tensor,
    lhs_zero_point: int,
    rhs_transposed: torch.Tensor,
    rhs_zero_point: int,
    output_zero_point: int,
    output_multiplier: int,
    output_shift: int,
) -> torch.Tensor:
    """Shape/dtype propagation for quantized_batch_matmul.

    `rhs_transposed` carries the right operand with its last two dims
    swapped, so the result shape is (batch, lhs_rows, rhs_cols).
    """
    lhs_batch, lhs_rows, lhs_inner = lhs.shape
    rhs_batch, rhs_cols, rhs_inner = rhs_transposed.shape
    assert lhs_batch == rhs_batch and lhs_inner == rhs_inner
    return torch.empty(
        (lhs_batch, lhs_rows, rhs_cols), dtype=torch.int8, device=lhs.device
    )


@impl(lib, "quantized_batch_matmul", "CompositeExplicitAutograd")
def quantized_batch_matmul_impl(
    lhs: torch.Tensor,
    lhs_zero_point: int,
    rhs_transposed: torch.Tensor,
    rhs_zero_point: int,
    output_zero_point: int,
    output_multiplier: int,
    output_shift: int,
) -> torch.Tensor:
    """Reference int8 batched matmul mirroring the CMSIS-NN kernel.

    The lhs/rhs zero-point arguments arrive already negated (CMSIS-NN
    offset convention), so they are added rather than subtracted.
    """
    shifted_lhs = lhs.to(torch.float32) + float(lhs_zero_point)
    shifted_rhs_t = rhs_transposed.to(torch.float32) + float(rhs_zero_point)
    # Undo the pre-transpose so bmm contracts over the shared inner dim.
    accumulator = torch.bmm(shifted_lhs, shifted_rhs_t.transpose(1, 2)).to(
        torch.int32
    )
    requantized = requantize_cmsis(accumulator, output_multiplier, output_shift)
    return torch.clamp(requantized + output_zero_point, -128, 127).to(torch.int8)


# ===================================================================
# MINIMUM/MAXIMUM OPERATION DEFINITIONS
# ===================================================================
Expand Down
6 changes: 6 additions & 0 deletions backends/cortex_m/ops/operators.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,9 @@
kernels:
- arg_meta: null
kernel_name: cortex_m::quantized_max_pool2d_out

# Binds the out-variant schema to the CMSIS-NN-backed C++ kernel.
# `rhs_transposed` is the right operand with its last two dims pre-swapped;
# the lhs/rhs zero points are negated upstream (CMSIS-NN offset convention).
- func: cortex_m::quantized_batch_matmul.out(Tensor lhs, int lhs_zero_point, Tensor rhs_transposed, int rhs_zero_point, int output_zero_point, int output_multiplier, int output_shift, *, Tensor(a!) out) -> Tensor(a!)
  variants: function
  kernels:
    - arg_meta: null
      kernel_name: cortex_m::quantized_batch_matmul_out
49 changes: 49 additions & 0 deletions backends/cortex_m/passes/convert_to_cortex_m_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from executorch.backends.transforms.utils import (
create_constant_placeholder,
get_param_tensor,
is_param_node,
)

from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass
Expand Down Expand Up @@ -372,6 +373,52 @@ def _get_transpose_conv2d_replacement(self, node) -> tuple:
)
return exir_ops.edge.cortex_m.quantized_transpose_conv2d.default, new_args

def _get_bmm_replacement(self, node):
    """Lower an aten.bmm node to cortex_m.quantized_batch_matmul.

    Returns (target, args) for the CMSIS-NN-backed op. The right operand
    is pre-transposed to (batch, cols, inner): at AOT time when it is a
    constant parameter, otherwise via an inserted runtime transpose node.
    """
    # Per-tensor quant params recorded by the quantizer for both inputs
    # and the output.
    lhs_scale = node.meta["input_qparams"][0].scale
    lhs_zp = node.meta["input_qparams"][0].zp
    rhs_scale = node.meta["input_qparams"][1].scale
    rhs_zp = node.meta["input_qparams"][1].zp
    output_scale = node.meta["output_qparams"][0].scale
    output_zp = node.meta["output_qparams"][0].zp

    # Fold the three scales into one fixed-point multiplier/shift pair
    # for the CMSIS-NN requantization path.
    output_mult, output_shift = quantize_multiplier_aot(
        (lhs_scale * rhs_scale) / output_scale
    )

    lhs_node = node.args[0]
    rhs_node = node.args[1]

    is_constant_rhs = is_param_node(self.exported_program, rhs_node)
    if is_constant_rhs:
        # Constant weights: transpose once ahead of time and register the
        # result as a new parameter placeholder in the exported program.
        rhs_tensor = get_param_tensor(self.exported_program, rhs_node)
        rhs_transposed_tensor = rhs_tensor.permute(0, 2, 1).contiguous()
        with node.graph.inserting_after(rhs_node):
            rhs_transposed = create_constant_placeholder(
                self.exported_program,
                node.graph,
                node.name + "_rhs_transposed",
                InputKind.PARAMETER,
                rhs_transposed_tensor,
            )
    else:
        # Dynamic rhs: insert a runtime transpose right before the matmul.
        # NOTE(review): the new node gets no meta (e.g. quant params)
        # copied from `node` — confirm downstream passes don't require it.
        with node.graph.inserting_before(node):
            rhs_transposed = node.graph.create_node(
                "call_function",
                target=exir_ops.edge.cortex_m.transpose.default,
                args=(rhs_node, [0, 2, 1]),
            )

    # CMSIS-NN expects negated zero points ("offsets") for the inputs;
    # the output zero point is passed as-is.
    args = (
        lhs_node,
        -lhs_zp,
        rhs_transposed,
        -rhs_zp,
        output_zp,
        output_mult,
        output_shift,
    )
    return exir_ops.edge.cortex_m.quantized_batch_matmul.default, args

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
modified = False
for node in graph_module.graph.nodes:
Expand All @@ -393,6 +440,8 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
op, args = self._get_transpose_conv2d_replacement(node)
else:
op, args = self._get_convolution_replacement(node)
case exir_ops.edge.aten.bmm.default:
op, args = self._get_bmm_replacement(node)
case _:
continue

Expand Down
27 changes: 27 additions & 0 deletions backends/cortex_m/quantizer/pattern_checkers.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,33 @@ def check_quantization_config(
return is_int8 and is_per_tensor


class CortexMBmmCheck(PatternCheck):
    """Pattern check for aten.bmm: 3-D operands with matching batch and
    contraction dims, quantized per-tensor to int8."""

    @classmethod
    def check_pattern(cls, pattern):
        # Inspect every two-input node in the pattern; reject on the first
        # shape constraint that a standard (B, M, K) x (B, K, N) bmm breaks.
        for candidate in pattern:
            if len(candidate.all_input_nodes) != 2:
                continue
            left = get_first_fake_tensor(candidate.all_input_nodes[0])
            right = get_first_fake_tensor(candidate.all_input_nodes[1])
            if left.dim() != 3 or right.dim() != 3:
                return False
            if left.shape[0] != right.shape[0] or left.shape[2] != right.shape[1]:
                return False
        return True

    @classmethod
    def check_quantization_config(
        cls, pattern: list[Node], quantization_config: CortexMQuantizationConfig
    ):
        # Require per-tensor quantization on both activations and int8 dtype.
        per_tensor_ok = PatternCheck.is_per_tensor(
            quantization_config.get_input_act_qspec()
        ) and PatternCheck.is_per_tensor(quantization_config.get_output_act_qspec())
        int8_ok = cls.is_int8_activations(quantization_config)
        return int8_ok and per_tensor_ok


class CortexMMaxPool2DCheck(PatternCheck):
@classmethod
def _pool_arg_as_bool(cls, node: Node, index: int, default: bool) -> bool:
Expand Down
6 changes: 6 additions & 0 deletions backends/cortex_m/quantizer/quantizer_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from executorch.backends.cortex_m.quantizer.pattern_checkers import (
CortexMAddMulCheck,
CortexMAvgPool2DCheck,
CortexMBmmCheck,
CortexMConv2DCheck,
CortexMConvTranspose2DCheck,
CortexMLinearCheck,
Expand Down Expand Up @@ -118,11 +119,16 @@
(torch.ops.aten.max_pool2d_with_indices.default,): CortexMMaxPool2DCheck,
}

# Maps the bmm op pattern to its Cortex-M quantization eligibility check.
BMM_OP_PATTERNS = {
    (torch.ops.aten.bmm.default,): CortexMBmmCheck,
}

# Union of all op-pattern -> pattern-check tables supported by the
# Cortex-M quantizer.
CORTEX_M_QUANTIZER_SUPPORT_DICT = (
    BINARY_OP_PATTERNS
    | LINEAR_OP_PATTERNS
    | CONV_OP_PATTERNS
    | SOFTMAX_OP_PATTERNS
    | CONV_TRANSPOSE_OP_PATTERNS
    | POOL_OP_PATTERNS
    | BMM_OP_PATTERNS
)
Loading
Loading