From f6baaf820da242e1b63b33ed6a9d901a9280fbc4 Mon Sep 17 00:00:00 2001 From: jingguo-st Date: Thu, 13 Jun 2024 02:29:52 +0000 Subject: [PATCH 1/2] reimpl nll loss v2 for ascend --- impl/ascend/aclnn/adaptor.hpp | 2 +- impl/ascend/functions/nlllossv2.cpp | 100 ++++++++++++++++++++++++++++ impl/ascend_npu/CMakeLists.txt | 1 + impl/ascend_npu/ascend_config.yaml | 4 +- 4 files changed, 104 insertions(+), 3 deletions(-) create mode 100644 impl/ascend/functions/nlllossv2.cpp diff --git a/impl/ascend/aclnn/adaptor.hpp b/impl/ascend/aclnn/adaptor.hpp index 484c10527..1bb5adca1 100644 --- a/impl/ascend/aclnn/adaptor.hpp +++ b/impl/ascend/aclnn/adaptor.hpp @@ -88,7 +88,7 @@ inline aclTensor* createAclTensorFromAscendTensor(const AscendTensor& input) { input.getAclDataType(), stride.data(), input.storageOffset(), - format, // input.getAclDataFormat(), // TODO(lljbash): op_plugin assume non-channel-last, why? + format, &storageSize, /*storageDimsNum=*/1, const_cast(storagePtr)); diff --git a/impl/ascend/functions/nlllossv2.cpp b/impl/ascend/functions/nlllossv2.cpp new file mode 100644 index 000000000..3156063ee --- /dev/null +++ b/impl/ascend/functions/nlllossv2.cpp @@ -0,0 +1,100 @@ +/** + * @file + * @author DeepLink + * @copyright (c) 2024, DeepLink. + */ + +#include "../aclnn/acl_scalar.hpp" +#include "../aclnn/adaptor.hpp" + +namespace impl { +namespace ascend { +diopiError_t diopiNLLLossV2(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiTensorHandle_t totalWeight, diopiConstTensorHandle_t input, + diopiConstTensorHandle_t target, diopiConstTensorHandle_t weight, diopiReduction_t reduction, int64_t ignoreIndex) { + if (input == nullptr) { + return diopiSuccess; + } + + AscendTensor inputAt(input); + if (inputAt.numel() <= 0) { + if (diopiReduction_t::ReductionMean == reduction) { + DIOPI_ASCEND_CALL_ACLNN(aclnnInpalceFillScalar, ctx, out, std::nanf("")); + } else if (diopiReduction_t::ReductionSum == reduction || diopiReduction_t::ReductionNone == reduction) { + DIOPI_ASCEND_CALL_ACLNN(aclnnInplaceZero, ctx, out); + } + return diopiSuccess; + } + + diopiTensorHandle_t weightTmp = const_cast(weight); + if (weightTmp == nullptr) { + const int64_t channel = inputAt.dim() >= 4 ? inputAt.shape(1) : inputAt.shape(-1); + std::vector weightSize{channel}; + diopiSize_t weightShape = vectorToDiopiSize(weightSize); + diopiRequireTensor(ctx, &weightTmp, &weightShape, nullptr, inputAt.dtype(), diopi_device); + DIOPI_ASCEND_CALL_ACLNN(aclnnInplaceOne, ctx, weightTmp); + } + + if (inputAt.dim() <= 2) { + DIOPI_ASCEND_CALL_ACLNN(aclnnNLLLoss, ctx, input, target, weightTmp, reduction, ignoreIndex, out, totalWeight); + } else if (inputAt.dim() == 4) { + DIOPI_ASCEND_CALL_ACLNN(aclnnNLLLoss2d, ctx, input, target, weightTmp, reduction, ignoreIndex, out, totalWeight); + } else { + AscendTensor outAt(out); + AscendTensor targetAt(target); + AscendTensor inputView = inputAt.view({inputAt.shape(0), inputAt.shape(1), inputAt.numel() / inputAt.shape(0) / inputAt.shape(1), 1}); + AscendTensor outView = (outAt.numel() > 1) ? outAt.view({outAt.shape(0), outAt.numel() / outAt.shape(0), 1}) : outAt; + AscendTensor targetView = targetAt.view({targetAt.shape(0), targetAt.numel() / targetAt.shape(0), 1}); + } + + return diopiSuccess; +} + +diopiError_t diopiNLLLossV2Backward(diopiContextHandle_t ctx, diopiTensorHandle_t gradInput, diopiConstTensorHandle_t gradOutput, + diopiConstTensorHandle_t input, diopiConstTensorHandle_t target, diopiConstTensorHandle_t weight, + diopiConstTensorHandle_t totalWeight, diopiReduction_t reduction, int64_t ignoreIndex) { + AscendTensor inputAt(input); + AscendTensor gradInputAt(gradInput); + if (input == nullptr || gradInput == nullptr || inputAt.numel() <= 0 || gradInputAt.numel() <= 0) { + return diopiSuccess; + } + /* + * A tensor representing the sum of weights for each element considered in the NLL loss computation. + * In case a weight tensor is provided, total_weight represents the sum of weights for all the non-ignored indices in the target tensor. + * When no weight tensor is provided, total_weight corresponds to the count of all non-ignored indices. + */ + diopiTensorHandle_t weightTmp = const_cast(weight); + if (weightTmp == nullptr) { + const int64_t channel = inputAt.dim() >= 4 ? inputAt.shape(1) : inputAt.shape(-1); + std::vector weightSize{channel}; + diopiSize_t weightShape = vectorToDiopiSize(weightSize); + diopiRequireTensor(ctx, &weightTmp, &weightShape, nullptr, inputAt.dtype(), diopi_device); + DIOPI_ASCEND_CALL_ACLNN(aclnnInplaceOne, ctx, weightTmp); + } + + if (inputAt.dim() <= 2) { + DIOPI_ASCEND_CALL_ACLNN(aclnnNLLLossBackward, ctx, gradOutput, input, target, weightTmp, reduction, ignoreIndex, totalWeight, gradInput); + } else if (inputAt.dim() == 4) { + DIOPI_ASCEND_CALL_ACLNN(aclnnNLLLoss2dBackward, ctx, gradOutput, input, target, weightTmp, reduction, ignoreIndex, totalWeight, gradInput); + } else { + AscendTensor gradIputAt(gradInput); + AscendTensor gradOutputAt(gradOutput); + AscendTensor targetAt(target); + + AscendTensor inputView = inputAt.view({inputAt.shape(0), inputAt.shape(1), inputAt.numel() / inputAt.shape(0) / inputAt.shape(1), 1}); + AscendTensor gradInputView = + gradInputAt.view({gradInputAt.shape(0), gradInputAt.shape(1), gradInputAt.numel() / gradInputAt.shape(0) / gradInputAt.shape(1), 1}); + AscendTensor gradOutputView; + if (gradOutputAt.numel() > 1) { + gradOutputView.view({gradOutputAt.shape(0), gradOutputAt.numel() / gradOutputAt.shape(0), 1}); + } else { + gradOutputView = gradOutputAt; + } + AscendTensor targetView = targetAt.view({targetAt.shape(0), targetAt.numel() / targetAt.shape(0), 1}); + DIOPI_ASCEND_CALL_ACLNN( + aclnnNLLLoss2dBackward, ctx, gradOutputView, inputView, targetView, weightTmp, reduction, ignoreIndex, totalWeight, gradInputView); + } + return diopiSuccess; +} + +} // namespace ascend +} // namespace impl diff --git a/impl/ascend_npu/CMakeLists.txt b/impl/ascend_npu/CMakeLists.txt index 10c036847..c0354c443 100755 --- a/impl/ascend_npu/CMakeLists.txt +++ b/impl/ascend_npu/CMakeLists.txt @@ -197,6 +197,7 @@ set(OLD_IMPL_SRC ${OLD_IMPL_DIR}/functions/matmul.cpp ${OLD_IMPL_DIR}/functions/max_pool2d.cpp ${OLD_IMPL_DIR}/functions/equal.cpp + ${OLD_IMPL_DIR}/functions/nlllossv2.cpp ${OLD_IMPL_DIR}/functions_mmcv/roi_align_npu.cpp ${OLD_IMPL_DIR}/functions_ext/rms_norm.cpp ${OLD_IMPL_DIR}/functions_ext/adamw.cpp diff --git a/impl/ascend_npu/ascend_config.yaml b/impl/ascend_npu/ascend_config.yaml index 32d9007cd..4e61809f2 100755 --- a/impl/ascend_npu/ascend_config.yaml +++ b/impl/ascend_npu/ascend_config.yaml @@ -162,6 +162,8 @@ ascend: - diopiNeInp - diopiNeInpScalar - diopiNeScalar +- diopiNLLLossV2 +- diopiNLLLossV2Backward - diopiNorm - diopiNormal - diopiNormalInp @@ -261,8 +263,6 @@ ascend_npu: - diopiMm - diopiNLLLoss - diopiNLLLossBackward -- diopiNLLLossV2 -- diopiNLLLossV2Backward - diopiFlashAttention - diopiFlashAttentionBackward - diopiFlashAttentionV2 From bc2b071429ed93b6673aed4a3d6615cb59da19b6 Mon Sep 17 00:00:00 2001 From: jingguo-st Date: Fri, 14 Jun 2024 02:23:09 +0000 Subject: [PATCH 2/2] fix fill inplace call for aclnn --- impl/ascend/functions/nlllossv2.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/impl/ascend/functions/nlllossv2.cpp b/impl/ascend/functions/nlllossv2.cpp index 3156063ee..87e9f02db 100644 --- a/impl/ascend/functions/nlllossv2.cpp +++ b/impl/ascend/functions/nlllossv2.cpp @@ -18,7 +18,8 @@ diopiError_t diopiNLLLossV2(diopiContextHandle_t ctx, diopiTensorHandle_t out, d AscendTensor inputAt(input); if (inputAt.numel() <= 0) { if (diopiReduction_t::ReductionMean == reduction) { - DIOPI_ASCEND_CALL_ACLNN(aclnnInpalceFillScalar, ctx, out, std::nanf("")); + diopiScalar_t nans{diopi_dtype_float64, std::nanf("")}; + DIOPI_ASCEND_CALL_ACLNN(aclnnInplaceFillScalar, ctx, out, &nans); } else if (diopiReduction_t::ReductionSum == reduction || diopiReduction_t::ReductionNone == reduction) { DIOPI_ASCEND_CALL_ACLNN(aclnnInplaceZero, ctx, out); } @@ -44,6 +45,7 @@ diopiError_t diopiNLLLossV2(diopiContextHandle_t ctx, diopiTensorHandle_t out, d AscendTensor inputView = inputAt.view({inputAt.shape(0), inputAt.shape(1), inputAt.numel() / inputAt.shape(0) / inputAt.shape(1), 1}); AscendTensor outView = (outAt.numel() > 1) ? outAt.view({outAt.shape(0), outAt.numel() / outAt.shape(0), 1}) : outAt; AscendTensor targetView = targetAt.view({targetAt.shape(0), targetAt.numel() / targetAt.shape(0), 1}); + DIOPI_ASCEND_CALL_ACLNN(aclnnNLLLoss2d, ctx, inputView, targetView, weightTmp, reduction, ignoreIndex, outView, totalWeight); } return diopiSuccess; @@ -85,7 +87,7 @@ diopiError_t diopiNLLLossV2Backward(diopiContextHandle_t ctx, diopiTensorHandle_ gradInputAt.view({gradInputAt.shape(0), gradInputAt.shape(1), gradInputAt.numel() / gradInputAt.shape(0) / gradInputAt.shape(1), 1}); AscendTensor gradOutputView; if (gradOutputAt.numel() > 1) { - gradOutputView.view({gradOutputAt.shape(0), gradOutputAt.numel() / gradOutputAt.shape(0), 1}); + gradOutputView = gradOutputAt.view({gradOutputAt.shape(0), gradOutputAt.numel() / gradOutputAt.shape(0), 1}); } else { gradOutputView = gradOutputAt; }