@@ -41,7 +41,6 @@ status_t init_conf_matmul(acl_matmul_conf_t &, memory_desc_t &src_md,
4141 const dim_t src_batch = helper.src_batch ();
4242 const dim_t wei_batch = helper.wei_batch ();
4343
44- // We can only broadcast on one of src or wei at once
4544 // ACL supports broadcast for 3D shapes, and 4D shapes
4645 // for e.g when ab in abcd is 1x1
4746 bool batch_ok = IMPLICATION (src_batch > 1 , wei_batch == 1 )
@@ -54,18 +53,19 @@ status_t init_conf_matmul(acl_matmul_conf_t &, memory_desc_t &src_md,
5453 bool with_bias = md.bias_desc .format_kind != format_kind::undef;
5554 ACL_CHECK_SUPPORT (with_bias, " ACL does not support bias for matmul" );
5655
57- // The two innermost dimensions can be transposed, but the batch dimensions
58- // must be the outermost
5956 using namespace format_tag ;
6057 auto src_tag = memory_desc_matches_one_of_tag (
6158 src_md, abcd, abdc, abc, acb, ab, ba);
59+ auto wei_tag = memory_desc_matches_one_of_tag (
60+ wei_md, abcd, abdc, abc, acb, ab, ba);
6261 auto dst_tag = memory_desc_matches_one_of_tag (dst_md, abcd, abc, ab, ba);
63- ACL_CHECK_SUPPORT (utils::one_of (format_tag::undef, src_tag, dst_tag),
62+ ACL_CHECK_SUPPORT (
63+ utils::one_of (format_tag::undef, src_tag, wei_tag, dst_tag),
6464 " Format tag is undefined" );
6565
66- // Transpose A (src)
66+ // Transpose A (src) or B (wei)
6767 amp.is_transA = helper.transA () == ' T' ;
68-
68+ amp.is_transB = helper.transB () == ' T' ;
6969 auto acl_src_data_t = acl_utils::get_acl_data_t (src_md.data_type );
7070 auto acl_wei_data_t = acl_utils::get_acl_data_t (wei_md.data_type );
7171 auto acl_dst_data_t = acl_utils::get_acl_data_t (dst_md.data_type );
@@ -74,14 +74,21 @@ status_t init_conf_matmul(acl_matmul_conf_t &, memory_desc_t &src_md,
7474 amp.src_acc_info = arm_compute::TensorInfo (
7575 arm_compute::TensorShape (M, K, 1 , src_batch), 1 ,
7676 acl_src_data_t );
77+ if (amp.is_transB )
78+ amp.wei_acc_info = arm_compute::TensorInfo (
79+ arm_compute::TensorShape (K, N, wei_batch), 1 , acl_wei_data_t );
7780
78- amp.src_tensor_info = arm_compute::TensorInfo (
81+ amp.src_info = arm_compute::TensorInfo (
7982 arm_compute::TensorShape (K, M, 1 , src_batch), 1 , acl_src_data_t );
80- amp.wei_tensor_info = arm_compute::TensorInfo (
83+ amp.wei_info = arm_compute::TensorInfo (
8184 arm_compute::TensorShape (N, K, wei_batch), 1 , acl_wei_data_t );
82- amp.dst_tensor_info = arm_compute::TensorInfo (
85+ amp.dst_info = arm_compute::TensorInfo (
8386 arm_compute::TensorShape (N, M, 1 , dst_batch), 1 , acl_dst_data_t );
8487
88+ bool is_fastmath_enabled = utils::one_of (
89+ attr.fpmath_mode_ , fpmath_mode::bf16 , fpmath_mode::any);
90+ amp.gemm_info .set_fast_math (is_fastmath_enabled);
91+
8592 // Set alpha (output scaling)
8693 // TODO: Add runtime scales support. Creation time scales will be remove
8794 // in 3.0.
@@ -91,45 +98,10 @@ status_t init_conf_matmul(acl_matmul_conf_t &, memory_desc_t &src_md,
9198 // Validate ACL transpose
9299 if (amp.is_transA )
93100 ACL_CHECK_VALID (arm_compute::NETranspose::validate (
94- &amp.src_acc_info , &amp.src_tensor_info ));
95-
96- bool is_fastmath_enabled = utils::one_of (
97- attr.fpmath_mode_ , fpmath_mode::bf16 , fpmath_mode::any);
98- amp.gemm_info .set_fast_math (is_fastmath_enabled);
99-
100- amp.gemm_info .set_fixed_format (true );
101-
102- // WeightFormat::ANY tells ACL we can handle any format
103- amp.gemm_info .set_weight_format (arm_compute::WeightFormat::ANY);
104-
105- // Get the format that the ACL kernel will expect the weights to be
106- // in (if a kernel exists). Note that these are referred to as fixed format
107- // kernels, because they require one specific weights format
108- arm_compute::WeightFormat expected_weight_format;
109- ACL_CHECK_VALID (arm_compute::NEGEMM::has_opt_impl (expected_weight_format,
110- &amp.src_tensor_info , &amp.wei_tensor_info , nullptr ,
111- &amp.dst_tensor_info , amp.alpha , 0.0f , amp.gemm_info ));
112-
113- // Set gemm weights info to the one returned by has_opt_impl
114- amp.gemm_info .set_weight_format (expected_weight_format);
115-
116- // has_opt_impl may return a non fast math kernel, even if we requested one
117- amp.gemm_info .set_fast_math (
118- arm_compute::is_fixed_format_fast_math (expected_weight_format));
119-
120- // Logical dimension indices
121- dim_t innermost_dim = wei_md.ndims - 1 ;
122- dim_t N_dim = innermost_dim;
123- dim_t K_dim = innermost_dim - 1 ;
124-
125- // The logical indices of dimensions related to the batch, ordered from
126- // innermost to outermost
127- std::vector<dim_t > batch_dims = {};
128- for (dim_t i = K_dim - 1 ; i >= 0 ; --i)
129- batch_dims.push_back (i);
130-
131- acl_utils::reorder_to_weight_format (amp.wei_tensor_info , wei_md,
132- expected_weight_format, K_dim, N_dim, {}, batch_dims);
&amp.src_acc_info , &amp.src_info ));
102+ if (amp.is_transB )
103+ ACL_CHECK_VALID (arm_compute::NETranspose::validate (
104+ &amp.wei_acc_info , &amp.wei_info ));
133105
134106 return status::success;
135107}
0 commit comments