@@ -4183,8 +4183,66 @@ diopiError_t diopiForeachnormScalar(diopiContextHandle_t ctx, diopiTensorHandle_
41834183 return diopiSuccess;
41844184}
41854185
4186+ diopiError_t diopiGroupNormGB (diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiTensorHandle_t save_mean, diopiTensorHandle_t save_invstd,
4187+ diopiConstTensorHandle_t input, diopiConstTensorHandle_t weight, diopiConstTensorHandle_t bias, int64_t num_groups,
4188+ double eps, diopiSize_t reduced_axes, const int64_t channel_axis) {
4189+ impl::aten::setCurStream (ctx);
4190+ auto atInput = impl::aten::buildATen (input);
4191+ auto axisSize = atInput.size (channel_axis);
4192+ auto k = axisSize / num_groups;
4193+ at::IntArrayRef atReducedAxes = impl::aten::buildAtIntArray (reduced_axes);
4194+ std::vector<int64_t > dims;
4195+ int64_t N = 1 ;
4196+ for (int i = 0 ; i < atInput.dim (); i++) {
4197+ if (i == channel_axis) {
4198+ continue ;
4199+ } else {
4200+ bool is_reduced_axis = false ;
4201+ for (int m = 0 ; m < reduced_axes.len ; m++) {
4202+ if (i == reduced_axes.data [m]) {
4203+ is_reduced_axis = true ;
4204+ break ;
4205+ }
4206+ }
4207+ if (is_reduced_axis) {
4208+ continue ;
4209+ } else {
4210+ dims.push_back (i);
4211+ N *= atInput.size (i);
4212+ }
4213+ }
4214+ }
4215+ dims.push_back (channel_axis);
4216+ int64_t HxW = 1 ;
4217+ for (auto i = 0 ; i < reduced_axes.len ; i++) {
4218+ dims.push_back (reduced_axes.data [i]);
4219+ HxW *= atInput.size (reduced_axes.data [i]);
4220+ }
4221+ auto C = atInput.size (channel_axis);
4222+ auto permutedInput = atInput.permute (dims);
4223+ auto permutedShape = permutedInput.sizes ();
4224+ auto reshapedInput = permutedInput.reshape ({N, C, HxW, 1 }).contiguous ();
4225+
4226+ auto atWeight = impl::aten::buildATen (weight);
4227+ auto atBias = impl::aten::buildATen (bias);
4228+ auto atOut = impl::aten::buildATen (out);
4229+ auto atSaveMean = impl::aten::buildATen (save_mean);
4230+ auto atSaveInvstd = impl::aten::buildATen (save_invstd);
4231+
4232+ std::vector<int64_t > reverse_order (dims.size ());
4233+ for (auto i = 0 ; i < atInput.dim (); i++) {
4234+ reverse_order[dims[i]] = i;
4235+ }
4236+ auto tempOut = CALL_ATEN_CUDA_FUNC (native_group_norm, reshapedInput, atWeight, atBias, N, C, HxW, num_groups, eps);
4237+ at::native::copy_ (atOut, std::get<0 >(tempOut).reshape (permutedShape).permute (reverse_order), true );
4238+ at::native::copy_ (atSaveMean, std::get<1 >(tempOut), true );
4239+ at::native::copy_ (atSaveInvstd, std::get<2 >(tempOut), true );
4240+ return diopiSuccess;
4241+ }
4242+
41864243diopiError_t diopiGroupNorm (diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiTensorHandle_t save_mean, diopiTensorHandle_t save_invstd,
4187- diopiConstTensorHandle_t input, diopiConstTensorHandle_t weight, diopiConstTensorHandle_t bias, int64_t num_groups, double eps) {
4244+ diopiConstTensorHandle_t input, diopiConstTensorHandle_t weight, diopiConstTensorHandle_t bias, int64_t num_groups,
4245+ double eps) {
41884246 impl::aten::setCurStream (ctx);
41894247 auto atInput = impl::aten::buildATen (input);
41904248 auto atWeight = impl::aten::buildATen (weight);
0 commit comments