@@ -552,6 +552,192 @@ status_t acl_init_conf_dw(acl_conv_conf_t &acp, memory_desc_t &src_md,
552552 return status::success;
553553}
554554
555+ status_t acl_init_conf_wino (acl_conv_conf_t &acp, memory_desc_t &src_md,
556+ memory_desc_t &weights_md, memory_desc_t &dst_md,
557+ memory_desc_t &bias_md, const convolution_desc_t &cd,
558+ const primitive_attr_t &attr) {
559+ const memory_desc_wrapper src_d (&src_md);
560+ const memory_desc_wrapper wei_d (&weights_md);
561+ const memory_desc_wrapper dst_d (&dst_md);
562+ const memory_desc_wrapper bia_d (&bias_md);
563+
564+ acp.fast_math
565+ = one_of (attr.fpmath_mode_ , fpmath_mode::bf16 , fpmath_mode::any);
566+
567+ // Compute Library currently supports forward propagation only
568+ const prop_kind_t prop_kind = cd.prop_kind ;
569+ const bool is_fwd = (prop_kind == dnnl_forward_training)
570+ || (prop_kind == dnnl_forward_inference);
571+ if (!is_fwd) return status::unimplemented;
572+
573+ const int with_groups = wei_d.ndims () == src_d.ndims () + 1 ;
574+ const int ndims = src_d.ndims ();
575+ const bool is_1d = ndims == 3 ;
576+ const bool is_3d = ndims == 5 ;
577+ bool is_nspc;
578+
579+ // Compute Library unsupported shape scenarios
580+ if (one_of (true , is_3d, is_1d, with_groups)) {
581+ return status::unimplemented;
582+ }
583+
584+ // batch size
585+ const int mb = src_d.dims ()[0 ];
586+
587+ // src/input channels, height, width
588+ const int ic = src_d.dims ()[1 ];
589+ const int ih = src_d.dims ()[ndims - 2 ];
590+ const int iw = src_d.dims ()[ndims - 1 ];
591+
592+ // dst/output channels, height, width
593+ const int oc = dst_d.dims ()[1 ];
594+ const int oh = dst_d.dims ()[ndims - 2 ];
595+ const int ow = dst_d.dims ()[ndims - 1 ];
596+
597+ // weights height and width
598+ const int kh = wei_d.dims ()[with_groups + ndims - 2 ];
599+ const int kw = wei_d.dims ()[with_groups + ndims - 1 ];
600+
601+ // height and width strides
602+ const int stride_h = cd.strides [ndims - 4 ];
603+ const int stride_w = cd.strides [ndims - 3 ];
604+
605+ // height and width dilations
606+ int dilate_h = cd.dilates [ndims - 4 ];
607+ int dilate_w = cd.dilates [ndims - 3 ];
608+ // oneDNN dilations: dk = 1 + (k_size - 1) * (dilate_size + 1)
609+ // Compute Library dilations: dk = dilate_size * (k_size - 1) + 1
610+ // thus acl_dilation = oneDNN_dilation + 1
611+ dilate_h += 1 ;
612+ dilate_w += 1 ;
613+
614+ acp.dilation_info = arm_compute::Size2D (dilate_w, dilate_h);
615+
616+ // left, right, top, bottom padding
617+ const int l_pad = cd.padding [0 ][1 ];
618+ const int t_pad = cd.padding [0 ][0 ];
619+ // Compute Library assumes the padding to be \geq 0, and r(b)_pad may be
620+ // equal to -1 in oneDNN for some cases, when the very right (bottom)
621+ // spatial elements of the input tensor are not used in the convolution.
622+ // On the other hand l(t)_pad are guaranteed to be non-negative.
623+ const int r_pad = std::max (static_cast <int >(cd.padding [1 ][1 ]), 0 );
624+ const int b_pad = std::max (static_cast <int >(cd.padding [1 ][0 ]), 0 );
625+
626+ acp.padstride_info = arm_compute::PadStrideInfo (stride_w, stride_h,
627+ static_cast <unsigned int >(l_pad), static_cast <unsigned int >(r_pad),
628+ static_cast <unsigned int >(t_pad), static_cast <unsigned int >(b_pad),
629+ arm_compute::DimensionRoundingType::FLOOR);
630+
631+ acp.with_bias = cd.bias_desc .format_kind != format_kind::undef;
632+
633+ auto set_or_check_tags = [&](format_tag_t desired_src_tag,
634+ format_tag_t desired_dst_tag) -> status_t {
635+ using namespace format_tag ;
636+ auto src_tag = any, dst_tag = any;
637+
638+ if (src_d.format_kind () == format_kind::any) {
639+ CHECK (memory_desc_init_by_tag (src_md, desired_src_tag));
640+ src_tag = desired_src_tag;
641+ } else {
642+ src_tag = memory_desc_matches_one_of_tag (src_md, nhwc, nchw);
643+ }
644+
645+ if (dst_d.format_kind () == format_kind::any) {
646+ CHECK (memory_desc_init_by_tag (dst_md, desired_dst_tag));
647+ dst_tag = desired_dst_tag;
648+ } else {
649+ dst_tag = memory_desc_matches_one_of_tag (dst_md, nhwc, nchw);
650+ }
651+
652+ if (acp.with_bias && bias_md.format_kind == format_kind::any)
653+ CHECK (memory_desc_init_by_tag (bias_md, x));
654+
655+ is_nspc = utils::one_of (src_tag, nhwc);
656+
657+ memory_desc_t want_wei_md = weights_md;
658+ auto wei_tag = is_nspc ? ohwi : oihw;
659+ CHECK (memory_desc_init_by_tag (want_wei_md, wei_tag));
660+
661+ // Compute Library does not support mismatching layouts
662+ if ((src_tag != wei_tag) || (src_tag != dst_tag))
663+ return status::unimplemented;
664+
665+ if (weights_md.format_kind == format_kind::any) {
666+ weights_md = want_wei_md;
667+ }
668+ return (want_wei_md == weights_md) ? status::success
669+ : status::unimplemented;
670+ };
671+
672+ auto default_dat_tag = format_tag::nhwc;
673+ if (set_or_check_tags (default_dat_tag, default_dat_tag) != status::success)
674+ return status::unimplemented;
675+
676+ const auto acl_layout = is_nspc ? arm_compute::DataLayout::NHWC
677+ : arm_compute::DataLayout::NCHW;
678+
679+ // For convolutions, int8 datatypes imply quantized types in ACL
680+ const auto is_int8 = utils::one_of (src_d.data_type (), s8, u8 )
681+ && wei_d.data_type () == s8;
682+
683+ auto acl_src_data_t
684+ = acl_utils::get_acl_data_t (src_d.data_type (), is_int8);
685+ auto acl_wei_data_t
686+ = acl_utils::get_acl_data_t (wei_d.data_type (), is_int8);
687+ auto acl_dst_data_t
688+ = acl_utils::get_acl_data_t (dst_d.data_type (), is_int8);
689+ auto acl_bia_data_t
690+ = acl_utils::get_acl_data_t (bia_d.data_type (), is_int8);
691+
692+ if (acl_bia_data_t == arm_compute::DataType::UNKNOWN)
693+ acl_bia_data_t = arm_compute::DataType::F32;
694+
695+ // clang-format off
696+ acp.src_tensor_info = arm_compute::TensorInfo (
697+ is_nspc ? arm_compute::TensorShape (ic, iw, ih, mb) :
698+ arm_compute::TensorShape (iw, ih, ic, mb),
699+ 1 ,
700+ acl_src_data_t ,
701+ acl_layout);
702+
703+ acp.wei_tensor_info = arm_compute::TensorInfo (
704+ is_nspc ? arm_compute::TensorShape (ic, kw, kh, oc) :
705+ arm_compute::TensorShape (kw, kh, ic, oc),
706+ 1 ,
707+ acl_wei_data_t ,
708+ acl_layout);
709+
710+ acp.dst_tensor_info = arm_compute::TensorInfo (
711+ is_nspc ? arm_compute::TensorShape (oc, ow, oh, mb) :
712+ arm_compute::TensorShape (ow, oh, oc, mb),
713+ 1 ,
714+ acl_dst_data_t ,
715+ acl_layout);
716+
717+ acp.bia_tensor_info = arm_compute::TensorInfo (
718+ acp.with_bias ? arm_compute::TensorShape (oc)
719+ : arm_compute::TensorShape (),
720+ 1 ,
721+ acl_bia_data_t ,
722+ acl_layout);
723+ // clang-format on
724+
725+ // Add quantization info to tensors
726+ if (is_int8) {
727+ // TODO: Add runtime scales support. Creation time scales will be remove
728+ // in 3.0.
729+ // const float *scales = attr.output_scales_.scales_;
730+ // acp.src_info.set_quantization_info(arm_compute::QuantizationInfo(1, 0));
731+ // acp.bia_info.set_quantization_info(arm_compute::QuantizationInfo(1, 0));
732+ // acp.wei_info.set_quantization_info(arm_compute::QuantizationInfo(1, 0));
733+ // acp.dst_info.set_quantization_info(
734+ // arm_compute::QuantizationInfo(1.0f / scales[0], 0));
735+ return status::unimplemented;
736+ }
737+
738+ return status::success;
739+ }
740+
555741status_t init_conf_dw (acl_conv_conf_t &acp, memory_desc_t &src_md,
556742 memory_desc_t &weights_md, memory_desc_t &dst_md,
557743 memory_desc_t &bias_md, const convolution_desc_t &cd,
@@ -577,6 +763,66 @@ status_t init_conf_dw(acl_conv_conf_t &acp, memory_desc_t &src_md,
577763 return status::success;
578764}
579765
766+ status_t init_conf_wino (acl_conv_conf_t &acp, memory_desc_t &src_md,
767+ memory_desc_t &weights_md, memory_desc_t &dst_md,
768+ memory_desc_t &bias_md, const convolution_desc_t &cd,
769+ const primitive_attr_t &attr) {
770+
771+ // Under these conditions, fallback to faster GEMM-based convolution
772+ // unless the user explicitly specifies Winograd algorithm
773+ // clang-format off
774+ if (one_of (true , src_md.dims [2 ] > 112 , // ih
775+ src_md.dims [3 ] > 112 , // iw
776+ src_md.dims [1 ] < 64 , // ic
777+ dst_md.dims [1 ] < 64 , // oc
778+ dnnl_get_max_threads () > 28 )
779+ && cd.alg_kind == alg_kind::convolution_auto) {
780+ return status::unimplemented;
781+ }
782+ // clang-format on
783+
784+ // General Compute Library checks, memory tags are also set there
785+ CHECK (acl_init_conf_wino (acp, src_md, weights_md, dst_md, bias_md, cd, attr));
786+
787+ const bool shape_ok
788+ // only unit strides allowed
789+ = (acp.padstride_info .stride () == std::pair<uint, uint> {1 , 1 })
790+ // Note: Compute Library supports arbitrary padding for wino kernels
791+ // but we only allow small padding to be consistent with oneDNN
792+ && (acp.padstride_info .pad ().first <= 1 ) // padding left/right
793+ && (acp.padstride_info .pad ().second <= 1 ) // padding top/bottom
794+ // only non-dilated convolutions allowed
795+ && (acp.dilation_info == arm_compute::Size2D (1 , 1 ));
796+
797+ ACL_CHECK_SUPPORT (!shape_ok, " shape not supported by winograd kernels" );
798+
799+ if (arm_compute::ConvolutionMethod::WINOGRAD !=
800+ arm_compute::NEConvolutionLayer::get_convolution_method (&acp.src_tensor_info ,
801+ &acp.wei_tensor_info ,
802+ // acp.with_bias ? &acp.bia_tensor_info : nullptr,
803+ &acp.dst_tensor_info ,
804+ acp.padstride_info ,
805+ arm_compute::WeightsInfo (),
806+ acp.dilation_info ,
807+ acp.act_info ,
808+ true )) {
809+ return status::unimplemented;
810+ }
811+ // clang-format off
812+ // Validate convolution manually to check for return status
813+ ACL_CHECK_VALID (arm_compute::NEWinogradConvolutionLayer::validate (
814+ &acp.src_tensor_info ,
815+ &acp.wei_tensor_info ,
816+ acp.with_bias ? &acp.bia_tensor_info : nullptr ,
817+ &acp.dst_tensor_info ,
818+ acp.padstride_info ,
819+ acp.act_info ,
820+ true )); // enable_fast_math flag in ACL Winograd
821+ // clang-format on
822+
823+ return status::success;
824+ }
825+
580826} // namespace acl_convolution_utils
581827
582828} // namespace acl
0 commit comments