Skip to content

Commit 60b922c

Browse files
committed
[FIX][ARM] Revert v3_2 acl matmul changes
The v3_2 ACL matmul requires the 'any' weights format in order to use the has_opt_impl() feature. An adaptation in the CPU plugin would be required, and performance is not guaranteed.
1 parent 3efb012 commit 60b922c

File tree

4 files changed

+64
-67
lines changed

4 files changed

+64
-67
lines changed

src/cpu/acl/matmul/acl_matmul.cpp

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*******************************************************************************
2-
* Copyright 2021-2023 Arm Ltd. and affiliates
2+
* Copyright 2021-2022 Arm Ltd. and affiliates
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -31,6 +31,7 @@ status_t acl_matmul_t::execute_forward(const exec_ctx_t &ctx) const {
3131
auto wei_base = CTX_IN_MEM(const data_t *, DNNL_ARG_WEIGHTS);
3232

3333
bool is_transA = pd()->amp_.is_transA;
34+
bool is_transB = pd()->amp_.is_transB;
3435
bool use_dst_acc = pd()->amp_.use_dst_acc;
3536

3637
std::lock_guard<std::mutex> _lock {this->mtx};
@@ -42,13 +43,29 @@ status_t acl_matmul_t::execute_forward(const exec_ctx_t &ctx) const {
4243
nullptr, &acl_obj.dst_tensor, pd()->amp_.alpha, 0.0f, pd()->amp_.gemm_info);
4344

4445
// Run transpose kernel
45-
if (is_transA) {
46+
if (is_transA && !is_transB) {
4647
acl_obj.src_tensor.allocator()->allocate();
4748
acl_obj.src_acc_tensor.allocator()->import_memory(
4849
const_cast<data_t *>(src_base));
4950
acl_obj.transA.run();
5051
acl_obj.wei_tensor.allocator()->import_memory(
5152
const_cast<data_t *>(wei_base));
53+
} else if (is_transB && !is_transA) {
54+
acl_obj.wei_tensor.allocator()->allocate();
55+
acl_obj.wei_acc_tensor.allocator()->import_memory(
56+
const_cast<data_t *>(wei_base));
57+
acl_obj.transB.run();
58+
acl_obj.src_tensor.allocator()->import_memory(
59+
const_cast<data_t *>(src_base));
60+
} else if (is_transA && is_transB) {
61+
acl_obj.src_tensor.allocator()->allocate();
62+
acl_obj.src_acc_tensor.allocator()->import_memory(
63+
const_cast<data_t *>(src_base));
64+
acl_obj.wei_tensor.allocator()->allocate();
65+
acl_obj.wei_acc_tensor.allocator()->import_memory(
66+
const_cast<data_t *>(wei_base));
67+
acl_obj.transA.run();
68+
acl_obj.transB.run();
5269
} else {
5370
acl_obj.src_tensor.allocator()->import_memory(
5471
const_cast<data_t *>(src_base));
@@ -57,7 +74,7 @@ status_t acl_matmul_t::execute_forward(const exec_ctx_t &ctx) const {
5774
}
5875

5976
if (use_dst_acc) {
60-
// Put the result in a new tensor, it will be accumulated to the dst
77+
// Put the result in a new tensor, it will be accumalated to the dst
6178
// during the post ops
6279
acl_obj.dst_tensor.allocator()->allocate();
6380
} else {
@@ -70,6 +87,7 @@ status_t acl_matmul_t::execute_forward(const exec_ctx_t &ctx) const {
7087
acl_obj.src_tensor.allocator()->free();
7188
acl_obj.wei_tensor.allocator()->free();
7289
if (is_transA) acl_obj.src_acc_tensor.allocator()->free();
90+
if (is_transB) acl_obj.wei_acc_tensor.allocator()->free();
7391

7492
void *dst = acl_obj.dst_tensor.buffer();
7593
pd()->post_ops.execute(ctx, dst);

src/cpu/acl/matmul/acl_matmul.hpp

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,15 +32,20 @@ struct acl_resource_t : public resource_t {
3232

3333
status_t configure(const acl_matmul_conf_t &amp) {
3434
if (!acl_obj_) return status::out_of_memory;
35-
acl_obj_->src_tensor.allocator()->init(amp.src_tensor_info);
36-
acl_obj_->wei_tensor.allocator()->init(amp.wei_tensor_info);
37-
acl_obj_->dst_tensor.allocator()->init(amp.dst_tensor_info);
35+
acl_obj_->src_tensor.allocator()->init(amp.src_info);
36+
acl_obj_->wei_tensor.allocator()->init(amp.wei_info);
37+
acl_obj_->dst_tensor.allocator()->init(amp.dst_info);
3838
// Configure transpose kernel for src, wei or both
3939
if (amp.is_transA) {
4040
acl_obj_->src_acc_tensor.allocator()->init(amp.src_acc_info);
4141
acl_obj_->transA.configure(
4242
&acl_obj_->src_acc_tensor, &acl_obj_->src_tensor);
4343
}
44+
if (amp.is_transB) {
45+
acl_obj_->wei_acc_tensor.allocator()->init(amp.wei_acc_info);
46+
acl_obj_->transB.configure(
47+
&acl_obj_->wei_acc_tensor, &acl_obj_->wei_tensor);
48+
}
4449
// Configure GEMM
4550
acl_obj_->gemm.configure(&acl_obj_->src_tensor, &acl_obj_->wei_tensor,
4651
nullptr, &acl_obj_->dst_tensor, amp.alpha, 0.0f, amp.gemm_info);
@@ -78,9 +83,7 @@ struct acl_matmul_t : public primitive_t {
7883
&& platform::has_data_type_support(data_type::f16);
7984
bool ok = is_dense_data()
8085
&& utils::one_of(true, is_fp32_ok, is_fp16_ok)
81-
&& !has_zero_dim_memory()
82-
&& weights_md_.format_kind == format_kind::any
83-
&& set_default_formats()
86+
&& !has_zero_dim_memory() && set_default_formats()
8487
&& attr()->has_default_values(
8588
smask_t::oscale | smask_t::post_ops)
8689
&& attr_oscale_ok() && !has_runtime_dims_or_strides();
@@ -95,9 +98,9 @@ struct acl_matmul_t : public primitive_t {
9598
amp_.use_dst_acc = post_ops.has_sum();
9699

97100
// Validate ACL GEMM
98-
ACL_CHECK_VALID(arm_compute::NEGEMM::validate(&amp_.src_tensor_info,
99-
&amp_.wei_tensor_info, nullptr, &amp_.dst_tensor_info,
100-
amp_.alpha, 0.0f, amp_.gemm_info));
101+
ACL_CHECK_VALID(arm_compute::NEGEMM::validate(&amp_.src_info,
102+
&amp_.wei_info, nullptr, &amp_.dst_info, amp_.alpha, 0.0f,
103+
amp_.gemm_info));
101104

102105
return status::success;
103106
}

src/cpu/acl/matmul/acl_matmul_utils.cpp

Lines changed: 20 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@ status_t init_conf_matmul(acl_matmul_conf_t &amp, memory_desc_t &src_md,
4141
const dim_t src_batch = helper.src_batch();
4242
const dim_t wei_batch = helper.wei_batch();
4343

44-
// We can only broadcast on one of src or wei at once
4544
// ACL supports broadcast for 3D shapes, and 4D shapes
4645
// for e.g when ab in abcd is 1x1
4746
bool batch_ok = IMPLICATION(src_batch > 1, wei_batch == 1)
@@ -54,18 +53,19 @@ status_t init_conf_matmul(acl_matmul_conf_t &amp, memory_desc_t &src_md,
5453
bool with_bias = md.bias_desc.format_kind != format_kind::undef;
5554
ACL_CHECK_SUPPORT(with_bias, "ACL does not support bias for matmul");
5655

57-
// The two innermost dimensions can be transposed, but the batch dimensions
58-
// must be the outermost
5956
using namespace format_tag;
6057
auto src_tag = memory_desc_matches_one_of_tag(
6158
src_md, abcd, abdc, abc, acb, ab, ba);
59+
auto wei_tag = memory_desc_matches_one_of_tag(
60+
wei_md, abcd, abdc, abc, acb, ab, ba);
6261
auto dst_tag = memory_desc_matches_one_of_tag(dst_md, abcd, abc, ab, ba);
63-
ACL_CHECK_SUPPORT(utils::one_of(format_tag::undef, src_tag, dst_tag),
62+
ACL_CHECK_SUPPORT(
63+
utils::one_of(format_tag::undef, src_tag, wei_tag, dst_tag),
6464
"Format tag is undefined");
6565

66-
// Transpose A (src)
66+
// Transpose A (src) or B (wei)
6767
amp.is_transA = helper.transA() == 'T';
68-
68+
amp.is_transB = helper.transB() == 'T';
6969
auto acl_src_data_t = acl_utils::get_acl_data_t(src_md.data_type);
7070
auto acl_wei_data_t = acl_utils::get_acl_data_t(wei_md.data_type);
7171
auto acl_dst_data_t = acl_utils::get_acl_data_t(dst_md.data_type);
@@ -74,14 +74,21 @@ status_t init_conf_matmul(acl_matmul_conf_t &amp, memory_desc_t &src_md,
7474
amp.src_acc_info = arm_compute::TensorInfo(
7575
arm_compute::TensorShape(M, K, 1, src_batch), 1,
7676
acl_src_data_t);
77+
if (amp.is_transB)
78+
amp.wei_acc_info = arm_compute::TensorInfo(
79+
arm_compute::TensorShape(K, N, wei_batch), 1, acl_wei_data_t);
7780

78-
amp.src_tensor_info = arm_compute::TensorInfo(
81+
amp.src_info = arm_compute::TensorInfo(
7982
arm_compute::TensorShape(K, M, 1, src_batch), 1, acl_src_data_t);
80-
amp.wei_tensor_info = arm_compute::TensorInfo(
83+
amp.wei_info = arm_compute::TensorInfo(
8184
arm_compute::TensorShape(N, K, wei_batch), 1, acl_wei_data_t);
82-
amp.dst_tensor_info = arm_compute::TensorInfo(
85+
amp.dst_info = arm_compute::TensorInfo(
8386
arm_compute::TensorShape(N, M, 1, dst_batch), 1, acl_dst_data_t);
8487

88+
bool is_fastmath_enabled = utils::one_of(
89+
attr.fpmath_mode_, fpmath_mode::bf16, fpmath_mode::any);
90+
amp.gemm_info.set_fast_math(is_fastmath_enabled);
91+
8592
// Set alpha (output scaling)
8693
// TODO: Add runtime scales support. Creation time scales will be remove
8794
// in 3.0.
@@ -91,45 +98,10 @@ status_t init_conf_matmul(acl_matmul_conf_t &amp, memory_desc_t &src_md,
9198
// Validate ACL transpose
9299
if (amp.is_transA)
93100
ACL_CHECK_VALID(arm_compute::NETranspose::validate(
94-
&amp.src_acc_info, &amp.src_tensor_info));
95-
96-
bool is_fastmath_enabled = utils::one_of(
97-
attr.fpmath_mode_, fpmath_mode::bf16, fpmath_mode::any);
98-
amp.gemm_info.set_fast_math(is_fastmath_enabled);
99-
100-
amp.gemm_info.set_fixed_format(true);
101-
102-
// WeightFormat::ANY tells ACL we can handle any format
103-
amp.gemm_info.set_weight_format(arm_compute::WeightFormat::ANY);
104-
105-
// Get the format that the ACL kernel will expect the weights to be
106-
// in (if a kernel exists). Note that these are referred to as fixed format
107-
// kernels, because they require one specific weights format
108-
arm_compute::WeightFormat expected_weight_format;
109-
ACL_CHECK_VALID(arm_compute::NEGEMM::has_opt_impl(expected_weight_format,
110-
&amp.src_tensor_info, &amp.wei_tensor_info, nullptr,
111-
&amp.dst_tensor_info, amp.alpha, 0.0f, amp.gemm_info));
112-
113-
// Set gemm weights info to the one returned by has_opt_impl
114-
amp.gemm_info.set_weight_format(expected_weight_format);
115-
116-
// has_opt_impl may return a non fast math kernel, even if we requested one
117-
amp.gemm_info.set_fast_math(
118-
arm_compute::is_fixed_format_fast_math(expected_weight_format));
119-
120-
// Logical dimension indices
121-
dim_t innermost_dim = wei_md.ndims - 1;
122-
dim_t N_dim = innermost_dim;
123-
dim_t K_dim = innermost_dim - 1;
124-
125-
// The logical indices of dimensions related to the batch, ordered from
126-
// innermost to outermost
127-
std::vector<dim_t> batch_dims = {};
128-
for (dim_t i = K_dim - 1; i >= 0; --i)
129-
batch_dims.push_back(i);
130-
131-
acl_utils::reorder_to_weight_format(amp.wei_tensor_info, wei_md,
132-
expected_weight_format, K_dim, N_dim, {}, batch_dims);
101+
&amp.src_acc_info, &amp.src_info));
102+
if (amp.is_transB)
103+
ACL_CHECK_VALID(arm_compute::NETranspose::validate(
104+
&amp.wei_acc_info, &amp.wei_info));
133105

134106
return status::success;
135107
}

src/cpu/acl/matmul/acl_matmul_utils.hpp

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*******************************************************************************
2-
* Copyright 2021-2023 Arm Ltd. and affiliates
2+
* Copyright 2021-2022 Arm Ltd. and affiliates
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -14,8 +14,8 @@
1414
* limitations under the License.
1515
*******************************************************************************/
1616

17-
#ifndef CPU_ACL_MATMUL_UTILS_HPP
18-
#define CPU_ACL_MATMUL_UTILS_HPP
17+
#ifndef CPU_AARCH64_ACL_MATMUL_UTILS_HPP
18+
#define CPU_AARCH64_ACL_MATMUL_UTILS_HPP
1919

2020
#include "cpu/matmul/cpu_matmul_pd.hpp"
2121

@@ -29,21 +29,25 @@ namespace acl {
2929
struct acl_matmul_obj_t {
3030
arm_compute::NEGEMM gemm;
3131
arm_compute::NETranspose transA;
32+
arm_compute::NETranspose transB;
3233
arm_compute::Tensor src_tensor;
3334
arm_compute::Tensor src_acc_tensor;
3435
arm_compute::Tensor wei_tensor;
36+
arm_compute::Tensor wei_acc_tensor;
3537
arm_compute::Tensor dst_tensor;
3638
};
3739

3840
struct acl_matmul_conf_t {
3941
bool is_transA;
42+
bool is_transB;
4043
// If this is true, the result of the matmul goes into a temporarily
4144
// allocated ACL tensor to be accumulated into the oneDNN dst during postops
4245
bool use_dst_acc;
43-
arm_compute::TensorInfo src_tensor_info;
46+
arm_compute::TensorInfo src_info;
4447
arm_compute::TensorInfo src_acc_info;
45-
arm_compute::TensorInfo wei_tensor_info;
46-
arm_compute::TensorInfo dst_tensor_info;
48+
arm_compute::TensorInfo wei_info;
49+
arm_compute::TensorInfo wei_acc_info;
50+
arm_compute::TensorInfo dst_info;
4751
arm_compute::GEMMInfo gemm_info;
4852
float alpha;
4953
};
@@ -61,4 +65,4 @@ status_t init_conf_matmul(acl_matmul_conf_t &amp, memory_desc_t &src_md,
6165
} // namespace impl
6266
} // namespace dnnl
6367

64-
#endif // CPU_ACL_MATMUL_UTILS_HPP
68+
#endif // CPU_AARCH64_ACL_MATMUL_UTILS_HPP

0 commit comments

Comments (0)