@@ -13,79 +13,63 @@ diopiError_t diopiBaddbmm(diopiContextHandle_t ctx, diopiTensorHandle_t out, dio
1313 diopiConstTensorHandle_t batch2, double beta, double alpha) {
1414 diopiDtype_t outDtype;
1515 diopiGetTensorDtype (out, &outDtype);
16- diopiDtype_t execType;
1716
18- // adjust the input's and output's data type
19- if (outDtype == diopi_dtype_float64) {
20- execType = diopi_dtype_float32;
21- } else {
22- execType = outDtype;
23- }
24-
25- AscendTensor inputCopy (input);
26- AscendTensor outputCopy (out);
27- AscendTensor batch1Copy (batch1);
28- AscendTensor batch2Copy (batch2);
29- castTensor (ctx, outputCopy, execType);
30- castTensor (ctx, batch1Copy, execType);
31- castTensor (ctx, inputCopy, execType);
32- castTensor (ctx, batch2Copy, execType);
17+ AscendTensor inputAt (input);
18+ AscendTensor outputAt (out);
19+ AscendTensor batch1At (batch1);
20+ AscendTensor batch2At (batch2);
3321
3422 // get the size of batch1 * batch2
35- AscendTensor asBatch1 = AscendTensor (batch1Copy);
36- AscendTensor asBatch2 = AscendTensor (batch2Copy);
37- std::vector<int64_t > batch1Shape = asBatch1.shape ();
38- std::vector<int64_t > batch2Shape = asBatch2.shape ();
23+ std::vector<int64_t > batch1Shape = batch1At.shape ();
24+ std::vector<int64_t > batch2Shape = batch2At.shape ();
3925 std::vector<int64_t > vectorSizeBatchMatMulTensor = {batch1Shape[0 ], batch1Shape[1 ], batch2Shape[2 ]};
4026
4127 // init a tensor according to the size of batch1 * batch2 ;
4228 diopiSize_t diopiSizeBatchMatMulTensor = vectorToDiopiSize (vectorSizeBatchMatMulTensor);
43- AscendTensor asBatchMatMulTensor ;
44- makeTensor (ctx, asBatchMatMulTensor , &diopiSizeBatchMatMulTensor, execType , diopiDevice_t::diopi_device);
29+ AscendTensor batchMatMulTensorAt ;
30+ makeTensor (ctx, batchMatMulTensorAt , &diopiSizeBatchMatMulTensor, outDtype , diopiDevice_t::diopi_device);
4531
4632 // does batch1/batch2 need to transpose?
4733 bool isSelfT = false ;
4834 bool isMat2T = false ;
4935
5036 // do batch1 times batch2 -> BatchMatMulTensor
5137 AclOpRunner<2 , 1 >(" BatchMatMul" , ctx)
52- .addInput (batch1Copy )
53- .addInput (batch2Copy )
54- .addOutput (asBatchMatMulTensor )
38+ .addInput (batch1At )
39+ .addInput (batch2At )
40+ .addOutput (batchMatMulTensorAt )
5541 .setAttr (" adj_x1" , isSelfT)
5642 .setAttr (" adj_x2" , isMat2T)
5743 .run ();
5844
5945 // init memory based on the size of alphaMulTensor and betaMulTensor
6046 AscendTensor alphaMulTensor;
6147 AscendTensor betaMulTensor;
62- makeTensorLike (ctx, alphaMulTensor, asBatchMatMulTensor, execType );
63- makeTensorLike (ctx, betaMulTensor, inputCopy, execType );
48+ makeTensorLike (ctx, alphaMulTensor, batchMatMulTensorAt, outDtype );
49+ makeTensorLike (ctx, betaMulTensor, inputAt, outDtype );
6450
6551 diopiScalar_t alphaScalar;
66- alphaScalar.stype = execType ;
52+ alphaScalar.stype = outDtype ;
6753 alphaScalar.fval = alpha;
6854 diopiScalar_t betaScalar;
69- betaScalar.stype = execType ;
55+ betaScalar.stype = outDtype ;
7056 betaScalar.fval = beta;
7157
7258 // transform ascendTensor to diopiTensorHandle_t
7359 diopiTensorHandle_t diopiAlphaMulTensor = const_cast <diopiTensorHandle_t>(alphaMulTensor.tensorHandle ());
7460 diopiTensorHandle_t diopiBateMulTensor = const_cast <diopiTensorHandle_t>(betaMulTensor.tensorHandle ());
75- diopiTensorHandle_t diopiAsBatchMatMulTensor = const_cast <diopiTensorHandle_t>(asBatchMatMulTensor .tensorHandle ());
76- diopiTensorHandle_t diopiInputCopy = const_cast <diopiTensorHandle_t>(inputCopy .tensorHandle ());
61+ diopiTensorHandle_t diopiAsBatchMatMulTensor = const_cast <diopiTensorHandle_t>(batchMatMulTensorAt .tensorHandle ());
62+ diopiTensorHandle_t diopiInput = const_cast <diopiTensorHandle_t>(inputAt .tensorHandle ());
7763
7864 // alpha times BatchMatMulTensor -> alphaMulTensor and beta times input -> betaMulTensor
7965 diopiMulScalar (ctx, diopiAlphaMulTensor, diopiAsBatchMatMulTensor, &alphaScalar);
80- diopiMulScalar (ctx, diopiBateMulTensor, diopiInputCopy , &betaScalar);
66+ diopiMulScalar (ctx, diopiBateMulTensor, diopiInput , &betaScalar);
8167
8268 diopiScalar_t other;
8369 other.fval = 1 ;
8470 other.stype = outDtype;
85- diopiTensorHandle_t diopiOutputCopy = const_cast <diopiTensorHandle_t>(outputCopy.tensorHandle ());
86- diopiAdd (ctx, diopiOutputCopy, diopiAlphaMulTensor, diopiBateMulTensor, &other);
87- diopiCastDtype (ctx, out, diopiOutputCopy);
88-
71+ diopiTensorHandle_t diopiOutput = const_cast <diopiTensorHandle_t>(outputAt.tensorHandle ());
72+ diopiAdd (ctx, diopiOutput, diopiAlphaMulTensor, diopiBateMulTensor, &other);
8973 return diopiSuccess;
9074}
9175
0 commit comments