diff --git a/impl/torch/functions/functions.cpp b/impl/torch/functions/functions.cpp index 94b6705bd..bc0acae1a 100644 --- a/impl/torch/functions/functions.cpp +++ b/impl/torch/functions/functions.cpp @@ -2254,8 +2254,11 @@ diopiError_t diopiConvolution2dBackward(diopiContextHandle_t ctx, diopiTensorHan at::native::copy_(atGradWeight, std::get<1>(tempOut), true); at::native::copy_(atGradBias, std::get<2>(tempOut), true); } else { - auto results = at::convolution_backward( - atGrad, atInput, atWeight, c10::nullopt, atStride, atPadding, atDilation, false, outputPadding, groups, {true, true, false}); + std::array output_mask{true, true, false}; + if (!grad_input) output_mask[0] = false; + if (!grad_weight) output_mask[1] = false; + auto results = at::native::convolution_backward( + atGrad, atInput, atWeight, c10::nullopt, atStride, atPadding, atDilation, false, outputPadding, groups, output_mask); impl::aten::updateATen2Tensor(ctx, std::get<0>(results), grad_input); impl::aten::updateATen2Tensor(ctx, std::get<1>(results), grad_weight); if (bias_sizes && grad_bias) {