From dc9011eddff5294f8462f13bb78526d49d9b64d6 Mon Sep 17 00:00:00 2001 From: zyf654321 Date: Wed, 28 Aug 2024 16:03:12 +0800 Subject: [PATCH 1/3] Access fused_adamw operator --- dipu/SupportedDiopiFunctions.txt | 1 + .../diopi_functions.yaml | 36 +++++++++++++++++++ dipu/third_party/DIOPI | 2 +- 3 files changed, 38 insertions(+), 1 deletion(-) diff --git a/dipu/SupportedDiopiFunctions.txt b/dipu/SupportedDiopiFunctions.txt index 547e75955..0f6a3b0ff 100644 --- a/dipu/SupportedDiopiFunctions.txt +++ b/dipu/SupportedDiopiFunctions.txt @@ -101,6 +101,7 @@ diopiForeachmulInpTensor diopiForeachmulScalar diopiForeachmulTensor diopiForeachnormScalar +diopiFusedAdamW diopiGather diopiGe diopiGeInp diff --git a/dipu/scripts/autogen_diopi_wrapper/diopi_functions.yaml b/dipu/scripts/autogen_diopi_wrapper/diopi_functions.yaml index 2759f7fb6..8e3b03ff0 100755 --- a/dipu/scripts/autogen_diopi_wrapper/diopi_functions.yaml +++ b/dipu/scripts/autogen_diopi_wrapper/diopi_functions.yaml @@ -1325,6 +1325,42 @@ ::diopiConstTensorHandle_t self_dtype_diopi = dipu::diopi_helper::toDiopiTensorHandle(self_dtype); interface: diopiProd(ctx, out, self_dtype_diopi, nullptr) +- schema: "_fused_adamw_(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, float lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> ()" + custom_code_at_the_beginning: | + std::vector diopiTensorHandles_self(self.size()); + for(size_t i=0; i < self.size(); ++i){ + diopiConstTensorHandle_t const_handle = dipu::diopi_helper::toDiopiTensorHandle(self.at(i)); + diopiTensorHandle_t handle = const_cast(const_handle); + diopiTensorHandles_self[i] = handle; + } + std::vector diopiTensorHandles_grads(grads.size()); + for(size_t i=0; i < grads.size(); ++i){ + diopiTensorHandles_grads[i] = dipu::diopi_helper::toDiopiTensorHandle(grads.at(i)); + } + std::vector diopiTensorHandles_exp_avgs(exp_avgs.size()); + for(size_t i=0; i < exp_avgs.size(); ++i){ + diopiConstTensorHandle_t const_handle = dipu::diopi_helper::toDiopiTensorHandle(exp_avgs.at(i)); + diopiTensorHandle_t handle = const_cast(const_handle); + diopiTensorHandles_exp_avgs[i] = handle; + } + std::vector diopiTensorHandles_exp_avg_sqs(exp_avg_sqs.size()); + for(size_t i=0; i < exp_avg_sqs.size(); ++i){ + diopiConstTensorHandle_t const_handle = dipu::diopi_helper::toDiopiTensorHandle(exp_avg_sqs.at(i)); + diopiTensorHandle_t handle = const_cast(const_handle); + diopiTensorHandles_exp_avg_sqs[i] = handle; + } + std::vector diopiTensorHandles_max_exp_avg_sqs(max_exp_avg_sqs.size()); + for(size_t i=0; i < max_exp_avg_sqs.size(); ++i){ + diopiConstTensorHandle_t const_handle = dipu::diopi_helper::toDiopiTensorHandle(max_exp_avg_sqs.at(i)); + diopiTensorHandle_t handle = const_cast(const_handle); + diopiTensorHandles_max_exp_avg_sqs[i] = handle; + } + std::vector diopiTensorHandles_state_steps(state_steps.size(), nullptr); + for(size_t i=0; i < state_steps.size(); ++i){ + diopiTensorHandles_state_steps[i] = dipu::diopi_helper::toDiopiTensorHandle(state_steps.at(i)); + } + interface: diopiFusedAdamW(ctx, diopiTensorHandles_self.data(), diopiTensorHandles_grads.data(), diopiTensorHandles_exp_avgs.data(), diopiTensorHandles_exp_avg_sqs.data(), diopiTensorHandles_max_exp_avg_sqs.data(), diopiTensorHandles_state_steps.data(), static_cast(self.size()), lr, beta1, beta2, eps, weight_decay, amsgrad, maximize) + - schema: prod.int_out(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) custom_code_at_the_beginning: | const auto self_dtype = at::native::to(self, dtype); diff --git a/dipu/third_party/DIOPI b/dipu/third_party/DIOPI index 65930a539..1d463c252 160000 --- a/dipu/third_party/DIOPI +++ b/dipu/third_party/DIOPI @@ -1 +1 @@ -Subproject commit 65930a539938b692a84ba77027e91686b3d2516d +Subproject commit 1d463c252edfd7b105e7486b5c1d338fddd26766 From 8a45540a5ed266610861105ebeddc97009bd345b Mon Sep 17 00:00:00 2001 From: zyf654321 Date: Tue, 3 Sep 2024 16:12:06 +0800 Subject: [PATCH 2/3] Access fused_adamw operator --- dipu/third_party/DIOPI | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dipu/third_party/DIOPI b/dipu/third_party/DIOPI index 1d463c252..880339063 160000 --- a/dipu/third_party/DIOPI +++ b/dipu/third_party/DIOPI @@ -1 +1 @@ -Subproject commit 1d463c252edfd7b105e7486b5c1d338fddd26766 +Subproject commit 880339063a4daae4925d2535cf02295f97a2d5c9 From de3b463d7067ff11420054dbb1af6fc88f377198 Mon Sep 17 00:00:00 2001 From: zyf654321 Date: Thu, 5 Sep 2024 10:23:34 +0800 Subject: [PATCH 3/3] Add fused_adamw operator on dipu --- dipu/third_party/DIOPI | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dipu/third_party/DIOPI b/dipu/third_party/DIOPI index 880339063..02f03c6ab 160000 --- a/dipu/third_party/DIOPI +++ b/dipu/third_party/DIOPI @@ -1 +1 @@ -Subproject commit 880339063a4daae4925d2535cf02295f97a2d5c9 +Subproject commit 02f03c6abb20aa39d1d978436a53a2e4ec242d65